feat: AI chatroom multi-agent collaborative discussion platform

- Implement agent management with AI-assisted system prompt generation
- Support multiple AI providers (OpenRouter, Zhipu, MiniMax, etc.)
- Implement chatrooms and the discussion engine
- Real-time message push over WebSocket
- Frontend built with React + Ant Design
- Backend built with FastAPI + MongoDB

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Claude Code
2026-02-03 19:20:02 +08:00
commit edbddf855d
76 changed files with 14681 additions and 0 deletions

backend/.env.example Normal file

@@ -0,0 +1,16 @@
# MongoDB configuration
MONGODB_URL=mongodb://localhost:27017
MONGODB_DB=ai_chatroom

# Server configuration
HOST=0.0.0.0
PORT=8000
DEBUG=true

# Security configuration (change these in production)
SECRET_KEY=your-secret-key-change-in-production
ENCRYPTION_KEY=your-encryption-key-32-bytes-long

# Proxy configuration (optional)
# DEFAULT_HTTP_PROXY=http://127.0.0.1:7890
# DEFAULT_HTTPS_PROXY=http://127.0.0.1:7890

backend/Dockerfile Normal file

@@ -0,0 +1,25 @@
# AI chatroom backend Dockerfile
FROM python:3.11-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy the dependency manifest
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code
COPY . .

# Expose the service port
EXPOSE 8000

# Startup command
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]

backend/adapters/__init__.py Normal file

@@ -0,0 +1,58 @@
"""
AI接口适配器模块
提供统一的AI调用接口
"""
from .base_adapter import BaseAdapter, AdapterResponse, ChatMessage
from .minimax_adapter import MiniMaxAdapter
from .zhipu_adapter import ZhipuAdapter
from .openrouter_adapter import OpenRouterAdapter
from .kimi_adapter import KimiAdapter
from .deepseek_adapter import DeepSeekAdapter
from .gemini_adapter import GeminiAdapter
from .ollama_adapter import OllamaAdapter
from .llmstudio_adapter import LLMStudioAdapter
__all__ = [
"BaseAdapter",
"AdapterResponse",
"ChatMessage",
"MiniMaxAdapter",
"ZhipuAdapter",
"OpenRouterAdapter",
"KimiAdapter",
"DeepSeekAdapter",
"GeminiAdapter",
"OllamaAdapter",
"LLMStudioAdapter",
]
# 适配器注册表
ADAPTER_REGISTRY = {
"minimax": MiniMaxAdapter,
"zhipu": ZhipuAdapter,
"openrouter": OpenRouterAdapter,
"kimi": KimiAdapter,
"deepseek": DeepSeekAdapter,
"gemini": GeminiAdapter,
"ollama": OllamaAdapter,
"llmstudio": LLMStudioAdapter,
}
def get_adapter(provider_type: str) -> type:
"""
根据提供商类型获取对应的适配器类
Args:
provider_type: 提供商类型标识
Returns:
适配器类
Raises:
ValueError: 未知的提供商类型
"""
adapter_class = ADAPTER_REGISTRY.get(provider_type.lower())
if not adapter_class:
raise ValueError(f"未知的AI提供商类型: {provider_type}")
return adapter_class
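
A minimal usage sketch of the registry (not part of the commit; the key and model values below are placeholders):

from adapters import get_adapter

adapter_cls = get_adapter("deepseek")  # -> DeepSeekAdapter
adapter = adapter_cls(api_key="sk-...", model="deepseek-chat")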

backend/adapters/base_adapter.py Normal file

@@ -0,0 +1,166 @@
"""
AI适配器基类
定义统一的AI调用接口
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional, AsyncGenerator
from datetime import datetime
@dataclass
class ChatMessage:
"""聊天消息"""
role: str # system, user, assistant
content: str # 消息内容
name: Optional[str] = None # 发送者名称(可选)
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
d = {"role": self.role, "content": self.content}
if self.name:
d["name"] = self.name
return d
@dataclass
class AdapterResponse:
"""适配器响应"""
success: bool # 是否成功
content: str = "" # 响应内容
error: Optional[str] = None # 错误信息
# 统计信息
prompt_tokens: int = 0 # 输入token数
completion_tokens: int = 0 # 输出token数
total_tokens: int = 0 # 总token数
# 元数据
model: str = "" # 使用的模型
finish_reason: str = "" # 结束原因
latency_ms: float = 0.0 # 延迟(毫秒)
# 工具调用结果
tool_calls: List[Dict[str, Any]] = field(default_factory=list)
def __post_init__(self):
if self.total_tokens == 0:
self.total_tokens = self.prompt_tokens + self.completion_tokens
class BaseAdapter(ABC):
"""
AI适配器基类
所有AI提供商适配器必须继承此类
"""
def __init__(
self,
api_key: str,
base_url: str,
model: str,
use_proxy: bool = False,
proxy_config: Optional[Dict[str, Any]] = None,
timeout: int = 60,
**kwargs
):
"""
初始化适配器
Args:
api_key: API密钥
base_url: API基础URL
model: 模型名称
use_proxy: 是否使用代理
proxy_config: 代理配置
timeout: 超时时间(秒)
**kwargs: 额外参数
"""
self.api_key = api_key
self.base_url = base_url
self.model = model
self.use_proxy = use_proxy
self.proxy_config = proxy_config or {}
self.timeout = timeout
self.extra_params = kwargs
@abstractmethod
async def chat(
self,
messages: List[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2000,
**kwargs
) -> AdapterResponse:
"""
发送聊天请求
Args:
messages: 消息列表
temperature: 温度参数
max_tokens: 最大token数
**kwargs: 额外参数
Returns:
适配器响应
"""
pass
@abstractmethod
async def chat_stream(
self,
messages: List[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2000,
**kwargs
) -> AsyncGenerator[str, None]:
"""
发送流式聊天请求
Args:
messages: 消息列表
temperature: 温度参数
max_tokens: 最大token数
**kwargs: 额外参数
Yields:
响应内容片段
"""
pass
@abstractmethod
async def test_connection(self) -> Dict[str, Any]:
"""
测试API连接
Returns:
测试结果字典,包含 success, message, latency_ms
"""
pass
def _build_messages(
self,
messages: List[ChatMessage]
) -> List[Dict[str, Any]]:
"""
构建消息列表
Args:
messages: ChatMessage列表
Returns:
字典格式的消息列表
"""
return [msg.to_dict() for msg in messages]
def _calculate_latency(self, start_time: datetime) -> float:
"""
计算延迟
Args:
start_time: 开始时间
Returns:
延迟毫秒数
"""
return (datetime.utcnow() - start_time).total_seconds() * 1000
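
To show how callers are expected to drive this interface, a small sketch (illustrative only; `adapter` is an instance of any concrete subclass):

import asyncio
from adapters.base_adapter import BaseAdapter, ChatMessage

async def ask(adapter: BaseAdapter) -> None:
    messages = [
        ChatMessage(role="system", content="You are a helpful assistant"),
        ChatMessage(role="user", content="Say hi", name="alice"),
    ]
    resp = await adapter.chat(messages, temperature=0.3, max_tokens=100)
    if resp.success:
        # total_tokens is backfilled in __post_init__ when the API omits it
        print(resp.content, resp.total_tokens, f"{resp.latency_ms:.0f}ms")
    else:
        print("Error:", resp.error)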

backend/adapters/deepseek_adapter.py Normal file

@@ -0,0 +1,197 @@
"""
DeepSeek适配器
支持DeepSeek大模型API
"""
import json
from datetime import datetime
from typing import List, Dict, Any, Optional, AsyncGenerator
from loguru import logger
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
from utils.proxy_handler import get_http_client
class DeepSeekAdapter(BaseAdapter):
"""
DeepSeek API适配器
兼容OpenAI API格式
"""
DEFAULT_BASE_URL = "https://api.deepseek.com/v1"
def __init__(
self,
api_key: str,
base_url: str = "",
model: str = "deepseek-chat",
use_proxy: bool = False,
proxy_config: Optional[Dict[str, Any]] = None,
timeout: int = 60,
**kwargs
):
super().__init__(
api_key=api_key,
base_url=base_url or self.DEFAULT_BASE_URL,
model=model,
use_proxy=use_proxy,
proxy_config=proxy_config,
timeout=timeout,
**kwargs
)
async def chat(
self,
messages: List[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2000,
**kwargs
) -> AdapterResponse:
"""发送聊天请求"""
start_time = datetime.utcnow()
try:
async with get_http_client(
use_proxy=self.use_proxy,
proxy_config=self.proxy_config,
timeout=self.timeout
) as client:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
payload = {
"model": self.model,
"messages": self._build_messages(messages),
"temperature": temperature,
"max_tokens": max_tokens,
**kwargs
}
response = await client.post(
f"{self.base_url}/chat/completions",
headers=headers,
json=payload
)
if response.status_code != 200:
error_text = response.text
logger.error(f"DeepSeek API错误: {response.status_code} - {error_text}")
return AdapterResponse(
success=False,
error=f"API错误: {response.status_code} - {error_text}",
latency_ms=self._calculate_latency(start_time)
)
data = response.json()
choice = data.get("choices", [{}])[0]
message = choice.get("message", {})
usage = data.get("usage", {})
return AdapterResponse(
success=True,
content=message.get("content", ""),
model=data.get("model", self.model),
finish_reason=choice.get("finish_reason", ""),
prompt_tokens=usage.get("prompt_tokens", 0),
completion_tokens=usage.get("completion_tokens", 0),
total_tokens=usage.get("total_tokens", 0),
latency_ms=self._calculate_latency(start_time),
tool_calls=message.get("tool_calls", [])
)
except Exception as e:
logger.error(f"DeepSeek请求异常: {e}")
return AdapterResponse(
success=False,
error=str(e),
latency_ms=self._calculate_latency(start_time)
)
async def chat_stream(
self,
messages: List[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2000,
**kwargs
) -> AsyncGenerator[str, None]:
"""发送流式聊天请求"""
try:
async with get_http_client(
use_proxy=self.use_proxy,
proxy_config=self.proxy_config,
timeout=self.timeout
) as client:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
payload = {
"model": self.model,
"messages": self._build_messages(messages),
"temperature": temperature,
"max_tokens": max_tokens,
"stream": True,
**kwargs
}
async with client.stream(
"POST",
f"{self.base_url}/chat/completions",
headers=headers,
json=payload
) as response:
async for line in response.aiter_lines():
if line.startswith("data: "):
data_str = line[6:]
if data_str.strip() == "[DONE]":
break
try:
data = json.loads(data_str)
delta = data.get("choices", [{}])[0].get("delta", {})
content = delta.get("content", "")
if content:
yield content
except json.JSONDecodeError:
continue
except Exception as e:
logger.error(f"DeepSeek流式请求异常: {e}")
yield f"[错误: {str(e)}]"
async def test_connection(self) -> Dict[str, Any]:
"""测试API连接"""
start_time = datetime.utcnow()
try:
test_messages = [
ChatMessage(role="user", content="你好,请回复'OK'")
]
response = await self.chat(
messages=test_messages,
temperature=0,
max_tokens=10
)
if response.success:
return {
"success": True,
"message": "连接成功",
"model": response.model,
"latency_ms": response.latency_ms
}
else:
return {
"success": False,
"message": response.error,
"latency_ms": response.latency_ms
}
except Exception as e:
return {
"success": False,
"message": str(e),
"latency_ms": self._calculate_latency(start_time)
}
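
A streaming call against this adapter would look roughly like this (sketch; the API key is a placeholder):

import asyncio
from adapters.deepseek_adapter import DeepSeekAdapter
from adapters.base_adapter import ChatMessage

async def main() -> None:
    adapter = DeepSeekAdapter(api_key="sk-...")
    # Streamed tokens arrive as plain text chunks
    async for chunk in adapter.chat_stream([ChatMessage(role="user", content="Count to five")]):
        print(chunk, end="", flush=True)

asyncio.run(main())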

backend/adapters/gemini_adapter.py Normal file

@@ -0,0 +1,250 @@
"""
Gemini适配器
支持Google Gemini大模型API
"""
import json
from datetime import datetime
from typing import List, Dict, Any, Optional, AsyncGenerator
from loguru import logger
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
from utils.proxy_handler import get_http_client
class GeminiAdapter(BaseAdapter):
"""
Google Gemini API适配器
使用Gemini的原生API格式
"""
DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
def __init__(
self,
api_key: str,
base_url: str = "",
model: str = "gemini-1.5-pro",
use_proxy: bool = True, # Gemini通常需要代理
proxy_config: Optional[Dict[str, Any]] = None,
timeout: int = 60,
**kwargs
):
super().__init__(
api_key=api_key,
base_url=base_url or self.DEFAULT_BASE_URL,
model=model,
use_proxy=use_proxy,
proxy_config=proxy_config,
timeout=timeout,
**kwargs
)
def _convert_messages_to_gemini(
self,
messages: List[ChatMessage]
) -> tuple[str, List[Dict[str, Any]]]:
"""
将消息转换为Gemini格式
Args:
messages: 标准消息列表
Returns:
(system_instruction, contents)
"""
system_instruction = ""
contents = []
for msg in messages:
if msg.role == "system":
system_instruction += msg.content + "\n"
else:
role = "user" if msg.role == "user" else "model"
contents.append({
"role": role,
"parts": [{"text": msg.content}]
})
return system_instruction.strip(), contents
async def chat(
self,
messages: List[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2000,
**kwargs
) -> AdapterResponse:
"""发送聊天请求"""
start_time = datetime.utcnow()
try:
async with get_http_client(
use_proxy=self.use_proxy,
proxy_config=self.proxy_config,
timeout=self.timeout
) as client:
system_instruction, contents = self._convert_messages_to_gemini(messages)
payload = {
"contents": contents,
"generationConfig": {
"temperature": temperature,
"maxOutputTokens": max_tokens,
"topP": kwargs.get("top_p", 0.95),
"topK": kwargs.get("top_k", 40)
}
}
# 添加系统指令
if system_instruction:
payload["systemInstruction"] = {
"parts": [{"text": system_instruction}]
}
url = f"{self.base_url}/models/{self.model}:generateContent?key={self.api_key}"
response = await client.post(
url,
json=payload
)
if response.status_code != 200:
error_text = response.text
logger.error(f"Gemini API错误: {response.status_code} - {error_text}")
return AdapterResponse(
success=False,
error=f"API错误: {response.status_code} - {error_text}",
latency_ms=self._calculate_latency(start_time)
)
data = response.json()
# 检查是否有候选回复
candidates = data.get("candidates", [])
if not candidates:
return AdapterResponse(
success=False,
error="没有生成回复",
latency_ms=self._calculate_latency(start_time)
)
candidate = candidates[0]
content = candidate.get("content", {})
parts = content.get("parts", [])
text = "".join(part.get("text", "") for part in parts)
# 获取token使用情况
usage = data.get("usageMetadata", {})
return AdapterResponse(
success=True,
content=text,
model=self.model,
finish_reason=candidate.get("finishReason", ""),
prompt_tokens=usage.get("promptTokenCount", 0),
completion_tokens=usage.get("candidatesTokenCount", 0),
total_tokens=usage.get("totalTokenCount", 0),
latency_ms=self._calculate_latency(start_time)
)
except Exception as e:
logger.error(f"Gemini请求异常: {e}")
return AdapterResponse(
success=False,
error=str(e),
latency_ms=self._calculate_latency(start_time)
)
async def chat_stream(
self,
messages: List[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2000,
**kwargs
) -> AsyncGenerator[str, None]:
"""发送流式聊天请求"""
try:
async with get_http_client(
use_proxy=self.use_proxy,
proxy_config=self.proxy_config,
timeout=self.timeout
) as client:
system_instruction, contents = self._convert_messages_to_gemini(messages)
payload = {
"contents": contents,
"generationConfig": {
"temperature": temperature,
"maxOutputTokens": max_tokens,
"topP": kwargs.get("top_p", 0.95),
"topK": kwargs.get("top_k", 40)
}
}
if system_instruction:
payload["systemInstruction"] = {
"parts": [{"text": system_instruction}]
}
url = f"{self.base_url}/models/{self.model}:streamGenerateContent?key={self.api_key}&alt=sse"
async with client.stream(
"POST",
url,
json=payload
) as response:
async for line in response.aiter_lines():
if line.startswith("data: "):
data_str = line[6:]
try:
data = json.loads(data_str)
candidates = data.get("candidates", [])
if candidates:
content = candidates[0].get("content", {})
parts = content.get("parts", [])
for part in parts:
text = part.get("text", "")
if text:
yield text
except json.JSONDecodeError:
continue
except Exception as e:
logger.error(f"Gemini流式请求异常: {e}")
yield f"[错误: {str(e)}]"
async def test_connection(self) -> Dict[str, Any]:
"""测试API连接"""
start_time = datetime.utcnow()
try:
test_messages = [
ChatMessage(role="user", content="Hello, respond with 'OK'")
]
response = await self.chat(
messages=test_messages,
temperature=0,
max_tokens=10
)
if response.success:
return {
"success": True,
"message": "连接成功",
"model": response.model,
"latency_ms": response.latency_ms
}
else:
return {
"success": False,
"message": response.error,
"latency_ms": response.latency_ms
}
except Exception as e:
return {
"success": False,
"message": str(e),
"latency_ms": self._calculate_latency(start_time)
}
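
To make the role mapping concrete, this is what `_convert_messages_to_gemini` yields for a short exchange (illustrative only):

from adapters.base_adapter import ChatMessage

msgs = [
    ChatMessage(role="system", content="Be terse."),
    ChatMessage(role="user", content="Hi"),
    ChatMessage(role="assistant", content="Hello"),
]
# system_instruction == "Be terse."
# contents == [
#     {"role": "user", "parts": [{"text": "Hi"}]},
#     {"role": "model", "parts": [{"text": "Hello"}]},  # assistant -> model
# ]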

backend/adapters/kimi_adapter.py Normal file

@@ -0,0 +1,197 @@
"""
Kimi适配器
支持月之暗面Kimi大模型API
"""
import json
from datetime import datetime
from typing import List, Dict, Any, Optional, AsyncGenerator
from loguru import logger
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
from utils.proxy_handler import get_http_client
class KimiAdapter(BaseAdapter):
"""
Kimi API适配器
兼容OpenAI API格式
"""
DEFAULT_BASE_URL = "https://api.moonshot.cn/v1"
def __init__(
self,
api_key: str,
base_url: str = "",
model: str = "moonshot-v1-8k",
use_proxy: bool = False,
proxy_config: Optional[Dict[str, Any]] = None,
timeout: int = 60,
**kwargs
):
super().__init__(
api_key=api_key,
base_url=base_url or self.DEFAULT_BASE_URL,
model=model,
use_proxy=use_proxy,
proxy_config=proxy_config,
timeout=timeout,
**kwargs
)
async def chat(
self,
messages: List[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2000,
**kwargs
) -> AdapterResponse:
"""发送聊天请求"""
start_time = datetime.utcnow()
try:
async with get_http_client(
use_proxy=self.use_proxy,
proxy_config=self.proxy_config,
timeout=self.timeout
) as client:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
payload = {
"model": self.model,
"messages": self._build_messages(messages),
"temperature": temperature,
"max_tokens": max_tokens,
**kwargs
}
response = await client.post(
f"{self.base_url}/chat/completions",
headers=headers,
json=payload
)
if response.status_code != 200:
error_text = response.text
logger.error(f"Kimi API错误: {response.status_code} - {error_text}")
return AdapterResponse(
success=False,
error=f"API错误: {response.status_code} - {error_text}",
latency_ms=self._calculate_latency(start_time)
)
data = response.json()
choice = data.get("choices", [{}])[0]
message = choice.get("message", {})
usage = data.get("usage", {})
return AdapterResponse(
success=True,
content=message.get("content", ""),
model=data.get("model", self.model),
finish_reason=choice.get("finish_reason", ""),
prompt_tokens=usage.get("prompt_tokens", 0),
completion_tokens=usage.get("completion_tokens", 0),
total_tokens=usage.get("total_tokens", 0),
latency_ms=self._calculate_latency(start_time),
tool_calls=message.get("tool_calls", [])
)
except Exception as e:
logger.error(f"Kimi请求异常: {e}")
return AdapterResponse(
success=False,
error=str(e),
latency_ms=self._calculate_latency(start_time)
)
async def chat_stream(
self,
messages: List[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2000,
**kwargs
) -> AsyncGenerator[str, None]:
"""发送流式聊天请求"""
try:
async with get_http_client(
use_proxy=self.use_proxy,
proxy_config=self.proxy_config,
timeout=self.timeout
) as client:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
payload = {
"model": self.model,
"messages": self._build_messages(messages),
"temperature": temperature,
"max_tokens": max_tokens,
"stream": True,
**kwargs
}
async with client.stream(
"POST",
f"{self.base_url}/chat/completions",
headers=headers,
json=payload
) as response:
async for line in response.aiter_lines():
if line.startswith("data: "):
data_str = line[6:]
if data_str.strip() == "[DONE]":
break
try:
data = json.loads(data_str)
delta = data.get("choices", [{}])[0].get("delta", {})
content = delta.get("content", "")
if content:
yield content
except json.JSONDecodeError:
continue
except Exception as e:
logger.error(f"Kimi流式请求异常: {e}")
yield f"[错误: {str(e)}]"
async def test_connection(self) -> Dict[str, Any]:
"""测试API连接"""
start_time = datetime.utcnow()
try:
test_messages = [
ChatMessage(role="user", content="你好,请回复'OK'")
]
response = await self.chat(
messages=test_messages,
temperature=0,
max_tokens=10
)
if response.success:
return {
"success": True,
"message": "连接成功",
"model": response.model,
"latency_ms": response.latency_ms
}
else:
return {
"success": False,
"message": response.error,
"latency_ms": response.latency_ms
}
except Exception as e:
return {
"success": False,
"message": str(e),
"latency_ms": self._calculate_latency(start_time)
}

backend/adapters/llmstudio_adapter.py Normal file

@@ -0,0 +1,253 @@
"""
LLM Studio适配器
支持本地LLM Studio服务
"""
import json
from datetime import datetime
from typing import List, Dict, Any, Optional, AsyncGenerator
from loguru import logger
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
from utils.proxy_handler import get_http_client
class LLMStudioAdapter(BaseAdapter):
"""
LLM Studio API适配器
兼容OpenAI API格式的本地服务
"""
DEFAULT_BASE_URL = "http://localhost:1234/v1"
def __init__(
self,
api_key: str = "lm-studio", # LLM Studio使用固定key
base_url: str = "",
model: str = "local-model",
use_proxy: bool = False, # 本地服务不需要代理
proxy_config: Optional[Dict[str, Any]] = None,
timeout: int = 120, # 本地模型可能需要更长时间
**kwargs
):
super().__init__(
api_key=api_key,
base_url=base_url or self.DEFAULT_BASE_URL,
model=model,
use_proxy=use_proxy,
proxy_config=proxy_config,
timeout=timeout,
**kwargs
)
async def chat(
self,
messages: List[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2000,
**kwargs
) -> AdapterResponse:
"""发送聊天请求"""
start_time = datetime.utcnow()
try:
async with get_http_client(
use_proxy=self.use_proxy,
proxy_config=self.proxy_config,
timeout=self.timeout
) as client:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
payload = {
"model": self.model,
"messages": self._build_messages(messages),
"temperature": temperature,
"max_tokens": max_tokens,
**kwargs
}
response = await client.post(
f"{self.base_url}/chat/completions",
headers=headers,
json=payload
)
if response.status_code != 200:
error_text = response.text
logger.error(f"LLM Studio API错误: {response.status_code} - {error_text}")
return AdapterResponse(
success=False,
error=f"API错误: {response.status_code} - {error_text}",
latency_ms=self._calculate_latency(start_time)
)
data = response.json()
choice = data.get("choices", [{}])[0]
message = choice.get("message", {})
usage = data.get("usage", {})
return AdapterResponse(
success=True,
content=message.get("content", ""),
model=data.get("model", self.model),
finish_reason=choice.get("finish_reason", ""),
prompt_tokens=usage.get("prompt_tokens", 0),
completion_tokens=usage.get("completion_tokens", 0),
total_tokens=usage.get("total_tokens", 0),
latency_ms=self._calculate_latency(start_time)
)
except Exception as e:
logger.error(f"LLM Studio请求异常: {e}")
return AdapterResponse(
success=False,
error=str(e),
latency_ms=self._calculate_latency(start_time)
)
async def chat_stream(
self,
messages: List[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2000,
**kwargs
) -> AsyncGenerator[str, None]:
"""发送流式聊天请求"""
try:
async with get_http_client(
use_proxy=self.use_proxy,
proxy_config=self.proxy_config,
timeout=self.timeout
) as client:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
payload = {
"model": self.model,
"messages": self._build_messages(messages),
"temperature": temperature,
"max_tokens": max_tokens,
"stream": True,
**kwargs
}
async with client.stream(
"POST",
f"{self.base_url}/chat/completions",
headers=headers,
json=payload
) as response:
async for line in response.aiter_lines():
if line.startswith("data: "):
data_str = line[6:]
if data_str.strip() == "[DONE]":
break
try:
data = json.loads(data_str)
delta = data.get("choices", [{}])[0].get("delta", {})
content = delta.get("content", "")
if content:
yield content
except json.JSONDecodeError:
continue
except Exception as e:
logger.error(f"LLM Studio流式请求异常: {e}")
yield f"[错误: {str(e)}]"
async def test_connection(self) -> Dict[str, Any]:
"""测试API连接"""
start_time = datetime.utcnow()
try:
# 首先检查服务是否在运行
async with get_http_client(
use_proxy=self.use_proxy,
proxy_config=self.proxy_config,
timeout=10
) as client:
# 获取模型列表
response = await client.get(
f"{self.base_url}/models",
headers={"Authorization": f"Bearer {self.api_key}"}
)
if response.status_code != 200:
return {
"success": False,
"message": "LLM Studio服务未运行或不可访问",
"latency_ms": self._calculate_latency(start_time)
}
data = response.json()
models = [m.get("id", "") for m in data.get("data", [])]
if not models:
return {
"success": False,
"message": "LLM Studio中没有加载的模型",
"latency_ms": self._calculate_latency(start_time)
}
# 发送测试消息
test_messages = [
ChatMessage(role="user", content="Hello, respond with 'OK'")
]
response = await self.chat(
messages=test_messages,
temperature=0,
max_tokens=10
)
if response.success:
return {
"success": True,
"message": "连接成功",
"model": response.model,
"latency_ms": response.latency_ms
}
else:
return {
"success": False,
"message": response.error,
"latency_ms": response.latency_ms
}
except Exception as e:
return {
"success": False,
"message": str(e),
"latency_ms": self._calculate_latency(start_time)
}
async def list_models(self) -> List[Dict[str, Any]]:
"""
列出LLM Studio中加载的模型
Returns:
模型信息列表
"""
try:
async with get_http_client(
use_proxy=self.use_proxy,
proxy_config=self.proxy_config,
timeout=10
) as client:
response = await client.get(
f"{self.base_url}/models",
headers={"Authorization": f"Bearer {self.api_key}"}
)
if response.status_code == 200:
data = response.json()
return data.get("data", [])
except Exception as e:
logger.error(f"获取LLM Studio模型列表失败: {e}")
return []

backend/adapters/minimax_adapter.py Normal file

@@ -0,0 +1,251 @@
"""
MiniMax适配器
支持MiniMax大模型API
"""
import json
from datetime import datetime
from typing import List, Dict, Any, Optional, AsyncGenerator
from loguru import logger
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
from utils.proxy_handler import get_http_client
class MiniMaxAdapter(BaseAdapter):
"""
MiniMax API适配器
支持abab系列模型
"""
DEFAULT_BASE_URL = "https://api.minimax.chat/v1"
def __init__(
self,
api_key: str,
base_url: str = "",
model: str = "abab6.5-chat",
use_proxy: bool = False,
proxy_config: Optional[Dict[str, Any]] = None,
timeout: int = 60,
**kwargs
):
super().__init__(
api_key=api_key,
base_url=base_url or self.DEFAULT_BASE_URL,
model=model,
use_proxy=use_proxy,
proxy_config=proxy_config,
timeout=timeout,
**kwargs
)
# MiniMax需要group_id
self.group_id = kwargs.get("group_id", "")
async def chat(
self,
messages: List[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2000,
**kwargs
) -> AdapterResponse:
"""发送聊天请求"""
start_time = datetime.utcnow()
try:
async with get_http_client(
use_proxy=self.use_proxy,
proxy_config=self.proxy_config,
timeout=self.timeout
) as client:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# MiniMax使用特殊的消息格式
minimax_messages = []
bot_setting = []
for msg in messages:
if msg.role == "system":
bot_setting.append({
"bot_name": "assistant",
"content": msg.content
})
else:
minimax_messages.append({
"sender_type": "USER" if msg.role == "user" else "BOT",
"sender_name": msg.name or ("用户" if msg.role == "user" else "assistant"),
"text": msg.content
})
payload = {
"model": self.model,
"messages": minimax_messages,
"bot_setting": bot_setting if bot_setting else [{"bot_name": "assistant", "content": "你是一个有帮助的助手"}],
"temperature": temperature,
"tokens_to_generate": max_tokens,
"mask_sensitive_info": False,
**kwargs
}
url = f"{self.base_url}/text/chatcompletion_v2"
if self.group_id:
url = f"{url}?GroupId={self.group_id}"
response = await client.post(
url,
headers=headers,
json=payload
)
if response.status_code != 200:
error_text = response.text
logger.error(f"MiniMax API错误: {response.status_code} - {error_text}")
return AdapterResponse(
success=False,
error=f"API错误: {response.status_code} - {error_text}",
latency_ms=self._calculate_latency(start_time)
)
data = response.json()
# 检查API返回的错误
if data.get("base_resp", {}).get("status_code", 0) != 0:
error_msg = data.get("base_resp", {}).get("status_msg", "未知错误")
return AdapterResponse(
success=False,
error=f"API错误: {error_msg}",
latency_ms=self._calculate_latency(start_time)
)
reply = data.get("reply", "")
usage = data.get("usage", {})
return AdapterResponse(
success=True,
content=reply,
model=self.model,
finish_reason=data.get("output_sensitive", False) and "content_filter" or "stop",
prompt_tokens=usage.get("prompt_tokens", 0),
completion_tokens=usage.get("completion_tokens", 0),
total_tokens=usage.get("total_tokens", 0),
latency_ms=self._calculate_latency(start_time)
)
except Exception as e:
logger.error(f"MiniMax请求异常: {e}")
return AdapterResponse(
success=False,
error=str(e),
latency_ms=self._calculate_latency(start_time)
)
async def chat_stream(
self,
messages: List[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2000,
**kwargs
) -> AsyncGenerator[str, None]:
"""发送流式聊天请求"""
try:
async with get_http_client(
use_proxy=self.use_proxy,
proxy_config=self.proxy_config,
timeout=self.timeout
) as client:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
minimax_messages = []
bot_setting = []
for msg in messages:
if msg.role == "system":
bot_setting.append({
"bot_name": "assistant",
"content": msg.content
})
else:
minimax_messages.append({
"sender_type": "USER" if msg.role == "user" else "BOT",
"sender_name": msg.name or ("用户" if msg.role == "user" else "assistant"),
"text": msg.content
})
payload = {
"model": self.model,
"messages": minimax_messages,
"bot_setting": bot_setting if bot_setting else [{"bot_name": "assistant", "content": "你是一个有帮助的助手"}],
"temperature": temperature,
"tokens_to_generate": max_tokens,
"stream": True,
**kwargs
}
url = f"{self.base_url}/text/chatcompletion_v2"
if self.group_id:
url = f"{url}?GroupId={self.group_id}"
async with client.stream(
"POST",
url,
headers=headers,
json=payload
) as response:
async for line in response.aiter_lines():
if line.startswith("data: "):
data_str = line[6:]
if data_str.strip() == "[DONE]":
break
try:
data = json.loads(data_str)
delta = data.get("choices", [{}])[0].get("delta", {})
content = delta.get("content", "")
if content:
yield content
except json.JSONDecodeError:
continue
except Exception as e:
logger.error(f"MiniMax流式请求异常: {e}")
yield f"[错误: {str(e)}]"
async def test_connection(self) -> Dict[str, Any]:
"""测试API连接"""
start_time = datetime.utcnow()
try:
test_messages = [
ChatMessage(role="user", content="你好,请回复'OK'")
]
response = await self.chat(
messages=test_messages,
temperature=0,
max_tokens=10
)
if response.success:
return {
"success": True,
"message": "连接成功",
"model": response.model,
"latency_ms": response.latency_ms
}
else:
return {
"success": False,
"message": response.error,
"latency_ms": response.latency_ms
}
except Exception as e:
return {
"success": False,
"message": str(e),
"latency_ms": self._calculate_latency(start_time)
}
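
Unlike the OpenAI-compatible adapters above, this one posts MiniMax's own request shape; an illustrative payload (values made up):

payload = {
    "model": "abab6.5-chat",
    "messages": [
        {"sender_type": "USER", "sender_name": "User", "text": "Hello"},
    ],
    "bot_setting": [
        {"bot_name": "assistant", "content": "You are a helpful assistant"},
    ],
    "temperature": 0.7,
    "tokens_to_generate": 2000,  # MiniMax's equivalent of max_tokens
}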

backend/adapters/ollama_adapter.py Normal file

@@ -0,0 +1,241 @@
"""
Ollama适配器
支持本地Ollama服务
"""
import json
from datetime import datetime
from typing import List, Dict, Any, Optional, AsyncGenerator
from loguru import logger
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
from utils.proxy_handler import get_http_client
class OllamaAdapter(BaseAdapter):
"""
Ollama API适配器
用于连接本地Ollama服务
"""
DEFAULT_BASE_URL = "http://localhost:11434"
def __init__(
self,
api_key: str = "", # Ollama通常不需要API密钥
base_url: str = "",
model: str = "llama2",
use_proxy: bool = False, # 本地服务通常不需要代理
proxy_config: Optional[Dict[str, Any]] = None,
timeout: int = 120, # 本地模型可能需要更长时间
**kwargs
):
super().__init__(
api_key=api_key,
base_url=base_url or self.DEFAULT_BASE_URL,
model=model,
use_proxy=use_proxy,
proxy_config=proxy_config,
timeout=timeout,
**kwargs
)
async def chat(
self,
messages: List[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2000,
**kwargs
) -> AdapterResponse:
"""发送聊天请求"""
start_time = datetime.utcnow()
try:
async with get_http_client(
use_proxy=self.use_proxy,
proxy_config=self.proxy_config,
timeout=self.timeout
) as client:
payload = {
"model": self.model,
"messages": self._build_messages(messages),
"options": {
"temperature": temperature,
"num_predict": max_tokens,
},
"stream": False
}
response = await client.post(
f"{self.base_url}/api/chat",
json=payload
)
if response.status_code != 200:
error_text = response.text
logger.error(f"Ollama API错误: {response.status_code} - {error_text}")
return AdapterResponse(
success=False,
error=f"API错误: {response.status_code} - {error_text}",
latency_ms=self._calculate_latency(start_time)
)
data = response.json()
message = data.get("message", {})
return AdapterResponse(
success=True,
content=message.get("content", ""),
model=data.get("model", self.model),
finish_reason=data.get("done_reason", "stop"),
prompt_tokens=data.get("prompt_eval_count", 0),
completion_tokens=data.get("eval_count", 0),
total_tokens=(
data.get("prompt_eval_count", 0) +
data.get("eval_count", 0)
),
latency_ms=self._calculate_latency(start_time)
)
except Exception as e:
logger.error(f"Ollama请求异常: {e}")
return AdapterResponse(
success=False,
error=str(e),
latency_ms=self._calculate_latency(start_time)
)
async def chat_stream(
self,
messages: List[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2000,
**kwargs
) -> AsyncGenerator[str, None]:
"""发送流式聊天请求"""
try:
async with get_http_client(
use_proxy=self.use_proxy,
proxy_config=self.proxy_config,
timeout=self.timeout
) as client:
payload = {
"model": self.model,
"messages": self._build_messages(messages),
"options": {
"temperature": temperature,
"num_predict": max_tokens,
},
"stream": True
}
async with client.stream(
"POST",
f"{self.base_url}/api/chat",
json=payload
) as response:
async for line in response.aiter_lines():
if line:
try:
data = json.loads(line)
message = data.get("message", {})
content = message.get("content", "")
if content:
yield content
# 检查是否完成
if data.get("done", False):
break
except json.JSONDecodeError:
continue
except Exception as e:
logger.error(f"Ollama流式请求异常: {e}")
yield f"[错误: {str(e)}]"
async def test_connection(self) -> Dict[str, Any]:
"""测试API连接"""
start_time = datetime.utcnow()
try:
# 首先检查服务是否在运行
async with get_http_client(
use_proxy=self.use_proxy,
proxy_config=self.proxy_config,
timeout=10
) as client:
# 检查模型是否存在
response = await client.get(f"{self.base_url}/api/tags")
if response.status_code != 200:
return {
"success": False,
"message": "Ollama服务未运行或不可访问",
"latency_ms": self._calculate_latency(start_time)
}
data = response.json()
models = [m.get("name", "").split(":")[0] for m in data.get("models", [])]
model_name = self.model.split(":")[0]
if model_name not in models:
return {
"success": False,
"message": f"模型 {self.model} 未安装,可用模型: {', '.join(models)}",
"latency_ms": self._calculate_latency(start_time)
}
# 发送测试消息
test_messages = [
ChatMessage(role="user", content="Hello, respond with 'OK'")
]
response = await self.chat(
messages=test_messages,
temperature=0,
max_tokens=10
)
if response.success:
return {
"success": True,
"message": "连接成功",
"model": response.model,
"latency_ms": response.latency_ms
}
else:
return {
"success": False,
"message": response.error,
"latency_ms": response.latency_ms
}
except Exception as e:
return {
"success": False,
"message": str(e),
"latency_ms": self._calculate_latency(start_time)
}
async def list_models(self) -> List[str]:
"""
列出本地可用的模型
Returns:
模型名称列表
"""
try:
async with get_http_client(
use_proxy=self.use_proxy,
proxy_config=self.proxy_config,
timeout=10
) as client:
response = await client.get(f"{self.base_url}/api/tags")
if response.status_code == 200:
data = response.json()
return [m.get("name", "") for m in data.get("models", [])]
except Exception as e:
logger.error(f"获取Ollama模型列表失败: {e}")
return []
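
Since Ollama runs locally, a quick smoke test might look like this (sketch; assumes an Ollama daemon on the default port):

import asyncio
from adapters.ollama_adapter import OllamaAdapter

async def main() -> None:
    adapter = OllamaAdapter(model="llama2")
    print(await adapter.list_models())      # e.g. ["llama2:latest", ...]
    print(await adapter.test_connection())  # {"success": True, "message": ...}

asyncio.run(main())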

backend/adapters/openrouter_adapter.py Normal file

@@ -0,0 +1,201 @@
"""
OpenRouter适配器
支持通过OpenRouter访问多种AI模型
"""
import json
from datetime import datetime
from typing import List, Dict, Any, Optional, AsyncGenerator
from loguru import logger
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
from utils.proxy_handler import get_http_client
class OpenRouterAdapter(BaseAdapter):
"""
OpenRouter API适配器
兼容OpenAI API格式
"""
DEFAULT_BASE_URL = "https://openrouter.ai/api/v1"
def __init__(
self,
api_key: str,
base_url: str = "",
model: str = "openai/gpt-4-turbo",
use_proxy: bool = False,
proxy_config: Optional[Dict[str, Any]] = None,
timeout: int = 60,
**kwargs
):
super().__init__(
api_key=api_key,
base_url=base_url or self.DEFAULT_BASE_URL,
model=model,
use_proxy=use_proxy,
proxy_config=proxy_config,
timeout=timeout,
**kwargs
)
async def chat(
self,
messages: List[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2000,
**kwargs
) -> AdapterResponse:
"""发送聊天请求"""
start_time = datetime.utcnow()
try:
async with get_http_client(
use_proxy=self.use_proxy,
proxy_config=self.proxy_config,
timeout=self.timeout
) as client:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"HTTP-Referer": kwargs.get("referer", "https://ai-chatroom.local"),
"X-Title": kwargs.get("title", "AI ChatRoom")
}
payload = {
"model": self.model,
"messages": self._build_messages(messages),
"temperature": temperature,
"max_tokens": max_tokens,
**kwargs
}
response = await client.post(
f"{self.base_url}/chat/completions",
headers=headers,
json=payload
)
if response.status_code != 200:
error_text = response.text
logger.error(f"OpenRouter API错误: {response.status_code} - {error_text}")
return AdapterResponse(
success=False,
error=f"API错误: {response.status_code} - {error_text}",
latency_ms=self._calculate_latency(start_time)
)
data = response.json()
choice = data.get("choices", [{}])[0]
message = choice.get("message", {})
usage = data.get("usage", {})
return AdapterResponse(
success=True,
content=message.get("content", ""),
model=data.get("model", self.model),
finish_reason=choice.get("finish_reason", ""),
prompt_tokens=usage.get("prompt_tokens", 0),
completion_tokens=usage.get("completion_tokens", 0),
total_tokens=usage.get("total_tokens", 0),
latency_ms=self._calculate_latency(start_time),
tool_calls=message.get("tool_calls", [])
)
except Exception as e:
logger.error(f"OpenRouter请求异常: {e}")
return AdapterResponse(
success=False,
error=str(e),
latency_ms=self._calculate_latency(start_time)
)
async def chat_stream(
self,
messages: List[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2000,
**kwargs
) -> AsyncGenerator[str, None]:
"""发送流式聊天请求"""
try:
async with get_http_client(
use_proxy=self.use_proxy,
proxy_config=self.proxy_config,
timeout=self.timeout
) as client:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"HTTP-Referer": kwargs.get("referer", "https://ai-chatroom.local"),
"X-Title": kwargs.get("title", "AI ChatRoom")
}
payload = {
"model": self.model,
"messages": self._build_messages(messages),
"temperature": temperature,
"max_tokens": max_tokens,
"stream": True,
**kwargs
}
async with client.stream(
"POST",
f"{self.base_url}/chat/completions",
headers=headers,
json=payload
) as response:
async for line in response.aiter_lines():
if line.startswith("data: "):
data_str = line[6:]
if data_str.strip() == "[DONE]":
break
try:
data = json.loads(data_str)
delta = data.get("choices", [{}])[0].get("delta", {})
content = delta.get("content", "")
if content:
yield content
except json.JSONDecodeError:
continue
except Exception as e:
logger.error(f"OpenRouter流式请求异常: {e}")
yield f"[错误: {str(e)}]"
async def test_connection(self) -> Dict[str, Any]:
"""测试API连接"""
start_time = datetime.utcnow()
try:
test_messages = [
ChatMessage(role="user", content="Hello, respond with 'OK'")
]
response = await self.chat(
messages=test_messages,
temperature=0,
max_tokens=10
)
if response.success:
return {
"success": True,
"message": "连接成功",
"model": response.model,
"latency_ms": response.latency_ms
}
else:
return {
"success": False,
"message": response.error,
"latency_ms": response.latency_ms
}
except Exception as e:
return {
"success": False,
"message": str(e),
"latency_ms": self._calculate_latency(start_time)
}

backend/adapters/zhipu_adapter.py Normal file

@@ -0,0 +1,197 @@
"""
智谱AI适配器
支持智谱GLM系列模型
"""
import json
from datetime import datetime
from typing import List, Dict, Any, Optional, AsyncGenerator
from loguru import logger
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
from utils.proxy_handler import get_http_client
class ZhipuAdapter(BaseAdapter):
"""
智谱AI API适配器
支持GLM-4、GLM-3等模型
"""
DEFAULT_BASE_URL = "https://open.bigmodel.cn/api/paas/v4"
def __init__(
self,
api_key: str,
base_url: str = "",
model: str = "glm-4",
use_proxy: bool = False,
proxy_config: Optional[Dict[str, Any]] = None,
timeout: int = 60,
**kwargs
):
super().__init__(
api_key=api_key,
base_url=base_url or self.DEFAULT_BASE_URL,
model=model,
use_proxy=use_proxy,
proxy_config=proxy_config,
timeout=timeout,
**kwargs
)
async def chat(
self,
messages: List[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2000,
**kwargs
) -> AdapterResponse:
"""发送聊天请求"""
start_time = datetime.utcnow()
try:
async with get_http_client(
use_proxy=self.use_proxy,
proxy_config=self.proxy_config,
timeout=self.timeout
) as client:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
payload = {
"model": self.model,
"messages": self._build_messages(messages),
"temperature": temperature,
"max_tokens": max_tokens,
**kwargs
}
response = await client.post(
f"{self.base_url}/chat/completions",
headers=headers,
json=payload
)
if response.status_code != 200:
error_text = response.text
logger.error(f"智谱API错误: {response.status_code} - {error_text}")
return AdapterResponse(
success=False,
error=f"API错误: {response.status_code} - {error_text}",
latency_ms=self._calculate_latency(start_time)
)
data = response.json()
choice = data.get("choices", [{}])[0]
message = choice.get("message", {})
usage = data.get("usage", {})
return AdapterResponse(
success=True,
content=message.get("content", ""),
model=data.get("model", self.model),
finish_reason=choice.get("finish_reason", ""),
prompt_tokens=usage.get("prompt_tokens", 0),
completion_tokens=usage.get("completion_tokens", 0),
total_tokens=usage.get("total_tokens", 0),
latency_ms=self._calculate_latency(start_time),
tool_calls=message.get("tool_calls", [])
)
except Exception as e:
logger.error(f"智谱API请求异常: {e}")
return AdapterResponse(
success=False,
error=str(e),
latency_ms=self._calculate_latency(start_time)
)
async def chat_stream(
self,
messages: List[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2000,
**kwargs
) -> AsyncGenerator[str, None]:
"""发送流式聊天请求"""
try:
async with get_http_client(
use_proxy=self.use_proxy,
proxy_config=self.proxy_config,
timeout=self.timeout
) as client:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
payload = {
"model": self.model,
"messages": self._build_messages(messages),
"temperature": temperature,
"max_tokens": max_tokens,
"stream": True,
**kwargs
}
async with client.stream(
"POST",
f"{self.base_url}/chat/completions",
headers=headers,
json=payload
) as response:
async for line in response.aiter_lines():
if line.startswith("data: "):
data_str = line[6:]
if data_str.strip() == "[DONE]":
break
try:
data = json.loads(data_str)
delta = data.get("choices", [{}])[0].get("delta", {})
content = delta.get("content", "")
if content:
yield content
except json.JSONDecodeError:
continue
except Exception as e:
logger.error(f"智谱流式请求异常: {e}")
yield f"[错误: {str(e)}]"
async def test_connection(self) -> Dict[str, Any]:
"""测试API连接"""
start_time = datetime.utcnow()
try:
test_messages = [
ChatMessage(role="user", content="你好,请回复'OK'")
]
response = await self.chat(
messages=test_messages,
temperature=0,
max_tokens=10
)
if response.success:
return {
"success": True,
"message": "连接成功",
"model": response.model,
"latency_ms": response.latency_ms
}
else:
return {
"success": False,
"message": response.error,
"latency_ms": response.latency_ms
}
except Exception as e:
return {
"success": False,
"message": str(e),
"latency_ms": self._calculate_latency(start_time)
}

backend/config.py Normal file

@@ -0,0 +1,50 @@
"""
应用配置模块
从环境变量加载配置项
"""
from pydantic_settings import BaseSettings
from typing import Optional
class Settings(BaseSettings):
"""应用配置类"""
# MongoDB配置
MONGODB_URL: str = "mongodb://localhost:27017"
MONGODB_DB: str = "ai_chatroom"
# 服务配置
HOST: str = "0.0.0.0"
PORT: int = 8000
DEBUG: bool = True
# 安全配置
SECRET_KEY: str = "your-secret-key-change-in-production"
ENCRYPTION_KEY: str = "your-encryption-key-32-bytes-long"
# CORS配置
CORS_ORIGINS: list = ["http://localhost:3000", "http://127.0.0.1:3000"]
# WebSocket配置
WS_HEARTBEAT_INTERVAL: int = 30
# 默认AI配置
DEFAULT_TIMEOUT: int = 60
DEFAULT_MAX_TOKENS: int = 2000
DEFAULT_TEMPERATURE: float = 0.7
# 代理配置(全局默认)
DEFAULT_HTTP_PROXY: Optional[str] = None
DEFAULT_HTTPS_PROXY: Optional[str] = None
# 速率限制
RATE_LIMIT_REQUESTS: int = 100
RATE_LIMIT_PERIOD: int = 60 # 秒
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
# 全局配置实例
settings = Settings()
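
Because `Settings` extends `BaseSettings`, every field can be overridden via the environment or the `.env` file shown earlier; a quick sketch (values hypothetical):

import os

os.environ["PORT"] = "9000"
os.environ["DEBUG"] = "false"

from config import Settings

print(Settings().PORT, Settings().DEBUG)  # 9000 False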

backend/database/__init__.py Normal file

@@ -0,0 +1,10 @@
"""
数据库模块
"""
from .connection import connect_db, close_db, get_database
__all__ = [
"connect_db",
"close_db",
"get_database",
]

backend/database/connection.py Normal file

@@ -0,0 +1,94 @@
"""
MongoDB数据库连接模块
使用Motor异步驱动
"""
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
from beanie import init_beanie
from loguru import logger
from typing import Optional
from config import settings
# 全局数据库客户端和数据库实例
_client: Optional[AsyncIOMotorClient] = None
_database: Optional[AsyncIOMotorDatabase] = None
async def connect_db() -> None:
"""
连接MongoDB数据库
初始化Beanie ODM
"""
global _client, _database
try:
_client = AsyncIOMotorClient(settings.MONGODB_URL)
_database = _client[settings.MONGODB_DB]
# 导入所有文档模型用于初始化Beanie
from models.ai_provider import AIProvider
from models.agent import Agent
from models.chatroom import ChatRoom
from models.message import Message
from models.discussion_result import DiscussionResult
from models.agent_memory import AgentMemory
# 初始化Beanie
await init_beanie(
database=_database,
document_models=[
AIProvider,
Agent,
ChatRoom,
Message,
DiscussionResult,
AgentMemory,
]
)
logger.info(f"已连接到MongoDB数据库: {settings.MONGODB_DB}")
except Exception as e:
logger.error(f"数据库连接失败: {e}")
raise
async def close_db() -> None:
"""
关闭数据库连接
"""
global _client
if _client:
_client.close()
logger.info("数据库连接已关闭")
def get_database() -> AsyncIOMotorDatabase:
"""
获取数据库实例
Returns:
MongoDB数据库实例
Raises:
RuntimeError: 数据库未初始化
"""
if _database is None:
raise RuntimeError("数据库未初始化请先调用connect_db()")
return _database
def get_client() -> AsyncIOMotorClient:
"""
获取数据库客户端
Returns:
MongoDB客户端实例
Raises:
RuntimeError: 客户端未初始化
"""
if _client is None:
raise RuntimeError("数据库客户端未初始化")
return _client

backend/main.py Normal file

@@ -0,0 +1,73 @@
"""
AI聊天室后端主入口
FastAPI应用启动文件
"""
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from loguru import logger
from config import settings
from database.connection import connect_db, close_db
from routers import providers, agents, chatrooms, discussions
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
应用生命周期管理
启动时连接数据库,关闭时断开连接
"""
logger.info("正在启动AI聊天室服务...")
await connect_db()
logger.info("数据库连接成功")
yield
logger.info("正在关闭AI聊天室服务...")
await close_db()
logger.info("服务已关闭")
# 创建FastAPI应用
app = FastAPI(
title="AI聊天室",
description="多Agent协作讨论平台",
version="1.0.0",
lifespan=lifespan
)
# 配置CORS
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 注册路由
app.include_router(providers.router, prefix="/api/providers", tags=["AI接口管理"])
app.include_router(agents.router, prefix="/api/agents", tags=["Agent管理"])
app.include_router(chatrooms.router, prefix="/api/chatrooms", tags=["聊天室管理"])
app.include_router(discussions.router, prefix="/api/discussions", tags=["讨论结果"])
@app.get("/")
async def root():
"""根路径健康检查"""
return {"message": "AI聊天室服务运行中", "version": "1.0.0"}
@app.get("/health")
async def health_check():
"""健康检查接口"""
return {"status": "healthy"}
if __name__ == "__main__":
uvicorn.run(
"main:app",
host=settings.HOST,
port=settings.PORT,
reload=settings.DEBUG
)
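
With the server running (`python main.py`, or the Dockerfile above), the health endpoints can be exercised like so (sketch, assuming the default host and port):

import httpx

print(httpx.get("http://localhost:8000/").json())        # {"message": ..., "version": "1.0.0"}
print(httpx.get("http://localhost:8000/health").json())  # {"status": "healthy"}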

backend/models/__init__.py Normal file

@@ -0,0 +1,25 @@
"""
数据模型模块
"""
from .ai_provider import AIProvider, ProxyConfig, RateLimit
from .agent import Agent, AgentCapabilities, AgentBehavior
from .chatroom import ChatRoom, ChatRoomConfig
from .message import Message, MessageType
from .discussion_result import DiscussionResult
from .agent_memory import AgentMemory, MemoryType
__all__ = [
"AIProvider",
"ProxyConfig",
"RateLimit",
"Agent",
"AgentCapabilities",
"AgentBehavior",
"ChatRoom",
"ChatRoomConfig",
"Message",
"MessageType",
"DiscussionResult",
"AgentMemory",
"MemoryType",
]

backend/models/agent.py Normal file

@@ -0,0 +1,168 @@
"""
Agent数据模型
定义AI聊天代理的配置结构
"""
from datetime import datetime
from typing import Optional, Dict, Any, List
from pydantic import Field
from beanie import Document
class AgentCapabilities:
"""Agent能力配置"""
memory_enabled: bool = False # 是否启用记忆
mcp_tools: List[str] = [] # 可用的MCP工具
skills: List[str] = [] # 可用的技能
multimodal: bool = False # 是否支持多模态
def __init__(
self,
memory_enabled: bool = False,
mcp_tools: Optional[List[str]] = None,
skills: Optional[List[str]] = None,
multimodal: bool = False
):
self.memory_enabled = memory_enabled
self.mcp_tools = mcp_tools or []
self.skills = skills or []
self.multimodal = multimodal
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return {
"memory_enabled": self.memory_enabled,
"mcp_tools": self.mcp_tools,
"skills": self.skills,
"multimodal": self.multimodal
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AgentCapabilities":
"""从字典创建"""
if not data:
return cls()
return cls(
memory_enabled=data.get("memory_enabled", False),
mcp_tools=data.get("mcp_tools", []),
skills=data.get("skills", []),
multimodal=data.get("multimodal", False)
)
class AgentBehavior:
"""Agent行为配置"""
speak_threshold: float = 0.5 # 发言阈值(判断是否需要发言)
max_speak_per_round: int = 2 # 每轮最多发言次数
speak_style: str = "balanced" # 发言风格: concise, balanced, detailed
def __init__(
self,
speak_threshold: float = 0.5,
max_speak_per_round: int = 2,
speak_style: str = "balanced"
):
self.speak_threshold = speak_threshold
self.max_speak_per_round = max_speak_per_round
self.speak_style = speak_style
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return {
"speak_threshold": self.speak_threshold,
"max_speak_per_round": self.max_speak_per_round,
"speak_style": self.speak_style
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AgentBehavior":
"""从字典创建"""
if not data:
return cls()
return cls(
speak_threshold=data.get("speak_threshold", 0.5),
max_speak_per_round=data.get("max_speak_per_round", 2),
speak_style=data.get("speak_style", "balanced")
)
class Agent(Document):
"""
Agent文档模型
存储AI代理的配置信息
"""
agent_id: str = Field(..., description="唯一标识")
name: str = Field(..., description="Agent名称")
role: str = Field(..., description="角色定义")
system_prompt: str = Field(..., description="系统提示词")
provider_id: str = Field(..., description="使用的AI接口ID")
# 模型参数
temperature: float = Field(default=0.7, ge=0, le=2, description="温度参数")
max_tokens: int = Field(default=2000, gt=0, description="最大token数")
# 能力配置
capabilities: Dict[str, Any] = Field(
default_factory=lambda: {
"memory_enabled": False,
"mcp_tools": [],
"skills": [],
"multimodal": False
},
description="能力配置"
)
# 行为配置
behavior: Dict[str, Any] = Field(
default_factory=lambda: {
"speak_threshold": 0.5,
"max_speak_per_round": 2,
"speak_style": "balanced"
},
description="行为配置"
)
# 外观配置
avatar: Optional[str] = Field(default=None, description="头像URL")
color: str = Field(default="#1890ff", description="代表颜色")
# 元数据
enabled: bool = Field(default=True, description="是否启用")
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class Settings:
name = "agents"
def get_capabilities(self) -> AgentCapabilities:
"""获取能力配置对象"""
return AgentCapabilities.from_dict(self.capabilities)
def get_behavior(self) -> AgentBehavior:
"""获取行为配置对象"""
return AgentBehavior.from_dict(self.behavior)
class Config:
json_schema_extra = {
"example": {
"agent_id": "product-manager",
"name": "产品经理",
"role": "产品规划和需求分析专家",
"system_prompt": "你是一位经验丰富的产品经理,擅长分析用户需求...",
"provider_id": "openrouter-gpt4",
"temperature": 0.7,
"max_tokens": 2000,
"capabilities": {
"memory_enabled": True,
"mcp_tools": ["web_search"],
"skills": [],
"multimodal": False
},
"behavior": {
"speak_threshold": 0.5,
"max_speak_per_round": 2,
"speak_style": "balanced"
},
"avatar": "https://example.com/avatar.png",
"color": "#1890ff"
}
}
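
Creating one of these documents with Beanie might look like this (sketch; assumes `connect_db()` has run inside an async context, and the IDs are hypothetical):

from models.agent import Agent

agent = Agent(
    agent_id="designer",
    name="Designer",
    role="UX and visual design expert",
    system_prompt="You are a senior product designer...",
    provider_id="openrouter-gpt4",
)
await agent.insert()  # persisted to the "agents" collection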

backend/models/agent_memory.py Normal file

@@ -0,0 +1,123 @@
"""
Agent记忆数据模型
定义Agent的记忆存储结构
"""
from datetime import datetime
from typing import Optional, List
from enum import Enum
from pydantic import Field
from beanie import Document
class MemoryType(str, Enum):
"""记忆类型枚举"""
SHORT_TERM = "short_term" # 短期记忆(会话内)
LONG_TERM = "long_term" # 长期记忆(跨会话)
EPISODIC = "episodic" # 情景记忆(特定事件)
SEMANTIC = "semantic" # 语义记忆(知识性)
class AgentMemory(Document):
"""
Agent记忆文档模型
存储Agent的记忆内容
"""
memory_id: str = Field(..., description="唯一标识")
agent_id: str = Field(..., description="Agent ID")
# 记忆内容
memory_type: str = Field(
default=MemoryType.SHORT_TERM.value,
description="记忆类型"
)
content: str = Field(..., description="记忆内容")
summary: str = Field(default="", description="内容摘要")
# 向量嵌入(用于相似度检索)
embedding: List[float] = Field(default_factory=list, description="向量嵌入")
# 元数据
importance: float = Field(default=0.5, ge=0, le=1, description="重要性评分")
access_count: int = Field(default=0, description="访问次数")
# 关联信息
source_room_id: Optional[str] = Field(default=None, description="来源聊天室ID")
source_discussion_id: Optional[str] = Field(default=None, description="来源讨论ID")
related_agents: List[str] = Field(default_factory=list, description="相关Agent列表")
tags: List[str] = Field(default_factory=list, description="标签")
# 时间戳
created_at: datetime = Field(default_factory=datetime.utcnow)
last_accessed: datetime = Field(default_factory=datetime.utcnow)
expires_at: Optional[datetime] = Field(default=None, description="过期时间")
class Settings:
name = "agent_memories"
indexes = [
[("agent_id", 1)],
[("memory_type", 1)],
[("importance", -1)],
[("last_accessed", -1)],
]
def access(self) -> None:
"""
记录访问,更新访问计数和时间
"""
self.access_count += 1
self.last_accessed = datetime.utcnow()
def is_expired(self) -> bool:
"""
检查记忆是否已过期
Returns:
是否过期
"""
if self.expires_at is None:
return False
return datetime.utcnow() > self.expires_at
def calculate_relevance_score(
self,
similarity: float,
time_decay_factor: float = 0.1
) -> float:
"""
计算综合相关性分数
结合向量相似度、重要性和时间衰减
Args:
similarity: 向量相似度 (0-1)
time_decay_factor: 时间衰减因子
Returns:
综合相关性分数
"""
# 计算时间衰减
hours_since_access = (datetime.utcnow() - self.last_accessed).total_seconds() / 3600
time_decay = 1.0 / (1.0 + time_decay_factor * hours_since_access)
# 综合评分
score = (
0.5 * similarity +
0.3 * self.importance +
0.2 * time_decay
)
return min(1.0, max(0.0, score))
class Config:
json_schema_extra = {
"example": {
"memory_id": "mem-001",
"agent_id": "product-manager",
"memory_type": "long_term",
"content": "在登录系统设计讨论中团队决定采用OAuth2.0方案",
"summary": "登录系统采用OAuth2.0",
"importance": 0.8,
"access_count": 5,
"source_room_id": "product-design-room",
"tags": ["登录", "OAuth", "认证"]
}
}
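
A worked example of the scoring formula, score = 0.5 * similarity + 0.3 * importance + 0.2 * time_decay (numbers made up):

from models.agent_memory import AgentMemory

mem = AgentMemory(
    memory_id="mem-001",
    agent_id="product-manager",
    content="Login system uses OAuth 2.0",
    importance=0.8,
)
# Freshly created, so hours_since_access ~= 0 and time_decay ~= 1.0:
# score = 0.5 * 0.9 + 0.3 * 0.8 + 0.2 * 1.0 = 0.89
print(mem.calculate_relevance_score(similarity=0.9))  # ~0.89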

backend/models/ai_provider.py Normal file

@@ -0,0 +1,149 @@
"""
AI接口提供商数据模型
定义AI服务配置结构
"""
from datetime import datetime
from typing import Optional, Dict, Any, List
from enum import Enum
from pydantic import Field
from beanie import Document
class ProviderType(str, Enum):
"""AI提供商类型枚举"""
MINIMAX = "minimax"
ZHIPU = "zhipu"
OPENROUTER = "openrouter"
KIMI = "kimi"
DEEPSEEK = "deepseek"
GEMINI = "gemini"
OLLAMA = "ollama"
LLMSTUDIO = "llmstudio"
class ProxyConfig:
"""代理配置"""
http_proxy: Optional[str] = None # HTTP代理地址
https_proxy: Optional[str] = None # HTTPS代理地址
no_proxy: List[str] = [] # 不使用代理的域名列表
def __init__(
self,
http_proxy: Optional[str] = None,
https_proxy: Optional[str] = None,
no_proxy: Optional[List[str]] = None
):
self.http_proxy = http_proxy
self.https_proxy = https_proxy
self.no_proxy = no_proxy or []
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return {
"http_proxy": self.http_proxy,
"https_proxy": self.https_proxy,
"no_proxy": self.no_proxy
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ProxyConfig":
"""从字典创建"""
if not data:
return cls()
return cls(
http_proxy=data.get("http_proxy"),
https_proxy=data.get("https_proxy"),
no_proxy=data.get("no_proxy", [])
)
class RateLimit:
"""速率限制配置"""
requests_per_minute: int = 60 # 每分钟请求数
tokens_per_minute: int = 100000 # 每分钟token数
def __init__(
self,
requests_per_minute: int = 60,
tokens_per_minute: int = 100000
):
self.requests_per_minute = requests_per_minute
self.tokens_per_minute = tokens_per_minute
def to_dict(self) -> Dict[str, int]:
"""转换为字典"""
return {
"requests_per_minute": self.requests_per_minute,
"tokens_per_minute": self.tokens_per_minute
}
@classmethod
def from_dict(cls, data: Dict[str, int]) -> "RateLimit":
"""从字典创建"""
if not data:
return cls()
return cls(
requests_per_minute=data.get("requests_per_minute", 60),
tokens_per_minute=data.get("tokens_per_minute", 100000)
)
class AIProvider(Document):
"""
AI接口提供商文档模型
存储各AI服务的配置信息
"""
provider_id: str = Field(..., description="唯一标识")
provider_type: str = Field(..., description="提供商类型: minimax, zhipu等")
name: str = Field(..., description="自定义名称")
api_key: str = Field(default="", description="API密钥(加密存储)")
base_url: str = Field(default="", description="API基础URL")
model: str = Field(..., description="使用的模型名称")
# 代理配置
use_proxy: bool = Field(default=False, description="是否使用代理")
proxy_config: Dict[str, Any] = Field(default_factory=dict, description="代理配置")
# 速率限制
rate_limit: Dict[str, int] = Field(
default_factory=lambda: {"requests_per_minute": 60, "tokens_per_minute": 100000},
description="速率限制配置"
)
# 其他配置
timeout: int = Field(default=60, description="超时时间(秒)")
extra_params: Dict[str, Any] = Field(default_factory=dict, description="额外参数")
# 元数据
enabled: bool = Field(default=True, description="是否启用")
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class Settings:
name = "ai_providers"
def get_proxy_config(self) -> ProxyConfig:
"""获取代理配置对象"""
return ProxyConfig.from_dict(self.proxy_config)
def get_rate_limit(self) -> RateLimit:
"""获取速率限制配置对象"""
return RateLimit.from_dict(self.rate_limit)
class Config:
json_schema_extra = {
"example": {
"provider_id": "openrouter-gpt4",
"provider_type": "openrouter",
"name": "OpenRouter GPT-4",
"api_key": "sk-xxx",
"base_url": "https://openrouter.ai/api/v1",
"model": "openai/gpt-4-turbo",
"use_proxy": True,
"proxy_config": {
"http_proxy": "http://127.0.0.1:7890",
"https_proxy": "http://127.0.0.1:7890"
},
"timeout": 60
}
}

backend/models/chatroom.py Normal file

@@ -0,0 +1,131 @@
"""
聊天室数据模型
定义讨论聊天室的配置结构
"""
from datetime import datetime
from typing import Optional, Dict, Any, List
from enum import Enum
from pydantic import Field
from beanie import Document
class ChatRoomStatus(str, Enum):
"""聊天室状态枚举"""
IDLE = "idle" # 空闲,等待开始
ACTIVE = "active" # 讨论进行中
PAUSED = "paused" # 暂停
COMPLETED = "completed" # 已完成
ERROR = "error" # 出错
class ChatRoomConfig:
"""聊天室配置"""
max_rounds: int = 50 # 最大轮数(备用终止条件)
message_history_size: int = 20 # 上下文消息数
consensus_threshold: float = 0.8 # 共识阈值
round_interval: float = 1.0 # 轮次间隔(秒)
allow_user_interrupt: bool = True # 允许用户中断
def __init__(
self,
max_rounds: int = 50,
message_history_size: int = 20,
consensus_threshold: float = 0.8,
round_interval: float = 1.0,
allow_user_interrupt: bool = True
):
self.max_rounds = max_rounds
self.message_history_size = message_history_size
self.consensus_threshold = consensus_threshold
self.round_interval = round_interval
self.allow_user_interrupt = allow_user_interrupt
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return {
"max_rounds": self.max_rounds,
"message_history_size": self.message_history_size,
"consensus_threshold": self.consensus_threshold,
"round_interval": self.round_interval,
"allow_user_interrupt": self.allow_user_interrupt
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ChatRoomConfig":
"""从字典创建"""
if not data:
return cls()
return cls(
max_rounds=data.get("max_rounds", 50),
message_history_size=data.get("message_history_size", 20),
consensus_threshold=data.get("consensus_threshold", 0.8),
round_interval=data.get("round_interval", 1.0),
allow_user_interrupt=data.get("allow_user_interrupt", True)
)
class ChatRoom(Document):
"""
聊天室文档模型
存储讨论聊天室的配置信息
"""
room_id: str = Field(..., description="唯一标识")
name: str = Field(..., description="聊天室名称")
description: str = Field(default="", description="描述")
objective: str = Field(default="", description="当前讨论目标")
# 参与者
agents: List[str] = Field(default_factory=list, description="Agent ID列表")
moderator_agent_id: Optional[str] = Field(default=None, description="共识判断Agent ID")
# 配置
config: Dict[str, Any] = Field(
default_factory=lambda: {
"max_rounds": 50,
"message_history_size": 20,
"consensus_threshold": 0.8,
"round_interval": 1.0,
"allow_user_interrupt": True
},
description="聊天室配置"
)
# 状态
status: str = Field(default=ChatRoomStatus.IDLE.value, description="当前状态")
current_round: int = Field(default=0, description="当前轮次")
current_discussion_id: Optional[str] = Field(default=None, description="当前讨论ID")
# 元数据
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
completed_at: Optional[datetime] = Field(default=None, description="完成时间")
class Settings:
name = "chatrooms"
def get_config(self) -> ChatRoomConfig:
"""获取配置对象"""
return ChatRoomConfig.from_dict(self.config)
def is_active(self) -> bool:
"""检查聊天室是否处于活动状态"""
return self.status == ChatRoomStatus.ACTIVE.value
class Config:
json_schema_extra = {
"example": {
"room_id": "product-design-room",
"name": "产品设计讨论室",
"description": "用于讨论新产品功能设计",
"objective": "设计一个用户友好的登录系统",
"agents": ["product-manager", "designer", "developer"],
"moderator_agent_id": "moderator",
"config": {
"max_rounds": 50,
"message_history_size": 20,
"consensus_threshold": 0.8
},
"status": "idle",
"current_round": 0
}
}
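以下是 ChatRoomConfig 与状态枚举的一个最小使用示意(纯内存操作,不涉及数据库;字段取值为虚构):

from models.chatroom import ChatRoomConfig, ChatRoomStatus

config = ChatRoomConfig.from_dict({"max_rounds": 30, "consensus_threshold": 0.9})
assert config.max_rounds == 30
assert config.round_interval == 1.0  # 未提供的字段回退到默认值
print(ChatRoomStatus.ACTIVE.value)   # 输出 "active"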

126
backend/models/discussion_result.py Normal file
View File

@@ -0,0 +1,126 @@
"""
讨论结果数据模型
定义讨论结果的结构
"""
from datetime import datetime
from typing import Optional, Dict, Any, List
from pydantic import Field
from beanie import Document
class DiscussionResult(Document):
"""
讨论结果文档模型
存储讨论的最终结果
"""
discussion_id: str = Field(..., description="讨论唯一标识")
room_id: str = Field(..., description="聊天室ID")
objective: str = Field(..., description="讨论目标")
# 共识结果
consensus_reached: bool = Field(default=False, description="是否达成共识")
confidence: float = Field(default=0.0, ge=0, le=1, description="共识置信度")
# 结果摘要
summary: str = Field(default="", description="讨论结果摘要")
action_items: List[str] = Field(default_factory=list, description="行动项列表")
unresolved_issues: List[str] = Field(default_factory=list, description="未解决的问题")
key_decisions: List[str] = Field(default_factory=list, description="关键决策")
# 统计信息
total_rounds: int = Field(default=0, description="总轮数")
total_messages: int = Field(default=0, description="总消息数")
participating_agents: List[str] = Field(default_factory=list, description="参与的Agent列表")
agent_contributions: Dict[str, int] = Field(
default_factory=dict,
description="各Agent发言次数统计"
)
# 状态
status: str = Field(default="in_progress", description="状态: in_progress, completed, failed")
end_reason: str = Field(default="", description="结束原因")
# 时间戳
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
completed_at: Optional[datetime] = Field(default=None, description="完成时间")
class Settings:
name = "discussions"
indexes = [
[("room_id", 1)],
[("created_at", -1)],
]
def mark_completed(
self,
consensus_reached: bool,
confidence: float,
summary: str,
action_items: Optional[List[str]] = None,
unresolved_issues: Optional[List[str]] = None,
end_reason: str = "consensus"
) -> None:
"""
标记讨论为已完成
Args:
consensus_reached: 是否达成共识
confidence: 置信度
summary: 结果摘要
action_items: 行动项
unresolved_issues: 未解决问题
end_reason: 结束原因
"""
self.consensus_reached = consensus_reached
self.confidence = confidence
self.summary = summary
self.action_items = action_items or []
self.unresolved_issues = unresolved_issues or []
self.status = "completed"
self.end_reason = end_reason
self.completed_at = datetime.utcnow()
self.updated_at = datetime.utcnow()
def update_stats(
self,
total_rounds: int,
total_messages: int,
agent_contributions: Dict[str, int]
) -> None:
"""
更新统计信息
Args:
total_rounds: 总轮数
total_messages: 总消息数
agent_contributions: Agent贡献统计
"""
self.total_rounds = total_rounds
self.total_messages = total_messages
self.agent_contributions = agent_contributions
self.participating_agents = list(agent_contributions.keys())
self.updated_at = datetime.utcnow()
class Config:
json_schema_extra = {
"example": {
"discussion_id": "disc-001",
"room_id": "product-design-room",
"objective": "设计用户登录系统",
"consensus_reached": True,
"confidence": 0.85,
"summary": "团队一致同意采用OAuth2.0 + 手机验证码的混合认证方案...",
"action_items": [
"设计OAuth2.0集成方案",
"开发短信验证服务",
"编写安全测试用例"
],
"unresolved_issues": [
"第三方登录的优先级排序"
],
"total_rounds": 15,
"total_messages": 45,
"status": "completed"
}
}
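下面的草图演示 update_stats 与 mark_completed 如何配合更新结果字段(Beanie Document 可在未初始化数据库时在内存中构造,insert/save 则需要先完成初始化;示例ID均为虚构):

from models.discussion_result import DiscussionResult

result = DiscussionResult(
    discussion_id="disc-demo",
    room_id="room-demo",
    objective="演示目标",
)
result.update_stats(
    total_rounds=3,
    total_messages=9,
    agent_contributions={"product-manager": 4, "developer": 5},
)
result.mark_completed(
    consensus_reached=True,
    confidence=0.9,
    summary="双方同意方案A",
    end_reason="consensus",
)
assert result.participating_agents == ["product-manager", "developer"]
assert result.status == "completed"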

123
backend/models/message.py Normal file
View File

@@ -0,0 +1,123 @@
"""
消息数据模型
定义聊天消息的结构
"""
from datetime import datetime
from typing import Optional, Dict, Any, List
from enum import Enum
from pydantic import Field
from beanie import Document
class MessageType(str, Enum):
"""消息类型枚举"""
TEXT = "text" # 纯文本
IMAGE = "image" # 图片
FILE = "file" # 文件
SYSTEM = "system" # 系统消息
ACTION = "action" # 动作消息(如调用工具)
class MessageAttachment:
"""消息附件"""
attachment_type: str # 附件类型: image, file
url: str # 资源URL
name: str # 文件名
size: int = 0 # 文件大小(字节)
mime_type: str = "" # MIME类型
def __init__(
self,
attachment_type: str,
url: str,
name: str,
size: int = 0,
mime_type: str = ""
):
self.attachment_type = attachment_type
self.url = url
self.name = name
self.size = size
self.mime_type = mime_type
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return {
"attachment_type": self.attachment_type,
"url": self.url,
"name": self.name,
"size": self.size,
"mime_type": self.mime_type
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MessageAttachment":
"""从字典创建"""
return cls(
attachment_type=data.get("attachment_type", ""),
url=data.get("url", ""),
name=data.get("name", ""),
size=data.get("size", 0),
mime_type=data.get("mime_type", "")
)
class Message(Document):
"""
消息文档模型
存储聊天消息
"""
message_id: str = Field(..., description="唯一标识")
room_id: str = Field(..., description="聊天室ID")
discussion_id: str = Field(..., description="讨论ID")
agent_id: Optional[str] = Field(default=None, description="发送Agent ID(系统消息为空)")
# 消息内容
content: str = Field(..., description="消息内容")
message_type: str = Field(default=MessageType.TEXT.value, description="消息类型")
attachments: List[Dict[str, Any]] = Field(default_factory=list, description="附件列表")
# 元数据
round: int = Field(default=0, description="所属轮次")
token_count: int = Field(default=0, description="token数量")
# 工具调用相关
tool_calls: List[Dict[str, Any]] = Field(default_factory=list, description="工具调用记录")
tool_results: List[Dict[str, Any]] = Field(default_factory=list, description="工具调用结果")
# 时间戳
created_at: datetime = Field(default_factory=datetime.utcnow)
class Settings:
name = "messages"
indexes = [
[("room_id", 1), ("created_at", 1)],
[("discussion_id", 1)],
[("agent_id", 1)],
]
def get_attachments(self) -> List[MessageAttachment]:
"""获取附件对象列表"""
return [MessageAttachment.from_dict(a) for a in self.attachments]
def is_from_agent(self, agent_id: str) -> bool:
"""检查消息是否来自指定Agent"""
return self.agent_id == agent_id
def is_system_message(self) -> bool:
"""检查是否为系统消息"""
return self.message_type == MessageType.SYSTEM.value
class Config:
json_schema_extra = {
"example": {
"message_id": "msg-001",
"room_id": "product-design-room",
"discussion_id": "disc-001",
"agent_id": "product-manager",
"content": "我认为登录系统应该支持多种认证方式...",
"message_type": "text",
"round": 1,
"token_count": 150
}
}
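MessageAttachment 以字典形式存入 Message.attachments,读取时再还原为对象。下面是一个内存中的往返示意(ID与URL均为虚构):

from models.message import Message, MessageAttachment, MessageType

att = MessageAttachment.from_dict({
    "attachment_type": "image",
    "url": "https://example.com/a.png",
    "name": "a.png",
})
msg = Message(
    message_id="msg-demo",
    room_id="room-demo",
    discussion_id="disc-demo",
    content="见附件",
    message_type=MessageType.IMAGE.value,
    attachments=[att.to_dict()],
)
assert msg.get_attachments()[0].name == "a.png"
assert not msg.is_system_message()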

42
backend/requirements.txt Normal file
View File

@@ -0,0 +1,42 @@
# FastAPI and server
fastapi==0.109.0
uvicorn[standard]==0.27.0
python-multipart==0.0.6
websockets==12.0
# MongoDB
motor==3.3.2
pymongo==4.6.1
beanie==1.25.0
# HTTP client
httpx==0.26.0
aiohttp==3.9.1
# AI SDK clients
openai==1.12.0
google-generativeai==0.3.2
zhipuai==2.0.1
# Data validation
pydantic==2.6.0
pydantic-settings==2.1.0
# Security
cryptography==42.0.2
python-jose[cryptography]==3.3.0
passlib[bcrypt]==1.7.4
# Utilities
python-dotenv==1.0.1
tenacity==8.2.3
numpy==1.26.4
# For embeddings and vector search
sentence-transformers==2.3.1
# Rate limiting
slowapi==0.1.9
# Logging
loguru==0.7.2

14
backend/routers/__init__.py Normal file
View File

@@ -0,0 +1,14 @@
"""
API路由模块
"""
from . import providers
from . import agents
from . import chatrooms
from . import discussions
__all__ = [
"providers",
"agents",
"chatrooms",
"discussions",
]
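这些路由模块需要在应用入口处统一挂载;main.py 不在本节选内,下面仅给出一种假设性的挂载草图(路由前缀与 tags 均为示意,实际以 main.py 为准):

from fastapi import FastAPI
from routers import providers, agents, chatrooms, discussions

app = FastAPI(title="AI聊天室")
app.include_router(providers.router, prefix="/api/providers", tags=["providers"])
app.include_router(agents.router, prefix="/api/agents", tags=["agents"])
app.include_router(chatrooms.router, prefix="/api/chatrooms", tags=["chatrooms"])
app.include_router(discussions.router, prefix="/api/discussions", tags=["discussions"])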

314
backend/routers/agents.py Normal file
View File

@@ -0,0 +1,314 @@
"""
Agent管理路由
"""
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, HTTPException, status
from pydantic import BaseModel, Field
from loguru import logger
from services.agent_service import AgentService, AGENT_TEMPLATES
router = APIRouter()
# ============ 请求/响应模型 ============
class CapabilitiesModel(BaseModel):
"""能力配置模型"""
memory_enabled: bool = False
mcp_tools: List[str] = []
skills: List[str] = []
multimodal: bool = False
class BehaviorModel(BaseModel):
"""行为配置模型"""
speak_threshold: float = 0.5
max_speak_per_round: int = 2
speak_style: str = "balanced"
class AgentCreateRequest(BaseModel):
"""创建Agent请求"""
name: str = Field(..., description="Agent名称")
role: str = Field(..., description="角色定义")
system_prompt: str = Field(..., description="系统提示词")
provider_id: str = Field(..., description="使用的AI接口ID")
temperature: float = Field(default=0.7, ge=0, le=2, description="温度参数")
max_tokens: int = Field(default=2000, gt=0, description="最大token数")
capabilities: Optional[CapabilitiesModel] = None
behavior: Optional[BehaviorModel] = None
avatar: Optional[str] = None
color: str = "#1890ff"
class Config:
json_schema_extra = {
"example": {
"name": "产品经理",
"role": "产品规划和需求分析专家",
"system_prompt": "你是一位经验丰富的产品经理...",
"provider_id": "openrouter-abc123",
"temperature": 0.7,
"max_tokens": 2000
}
}
class AgentUpdateRequest(BaseModel):
"""更新Agent请求"""
name: Optional[str] = None
role: Optional[str] = None
system_prompt: Optional[str] = None
provider_id: Optional[str] = None
temperature: Optional[float] = Field(default=None, ge=0, le=2)
max_tokens: Optional[int] = Field(default=None, gt=0)
capabilities: Optional[CapabilitiesModel] = None
behavior: Optional[BehaviorModel] = None
avatar: Optional[str] = None
color: Optional[str] = None
enabled: Optional[bool] = None
class AgentResponse(BaseModel):
"""Agent响应"""
agent_id: str
name: str
role: str
system_prompt: str
provider_id: str
temperature: float
max_tokens: int
capabilities: Dict[str, Any]
behavior: Dict[str, Any]
avatar: Optional[str]
color: str
enabled: bool
created_at: str
updated_at: str
class AgentTestRequest(BaseModel):
"""Agent测试请求"""
message: str = "你好,请简单介绍一下你自己。"
class AgentTestResponse(BaseModel):
"""Agent测试响应"""
success: bool
message: str
response: Optional[str] = None
model: Optional[str] = None
tokens: Optional[int] = None
latency_ms: Optional[float] = None
class TemplateResponse(BaseModel):
"""模板响应"""
template_id: str
name: str
role: str
system_prompt: str
color: str
class GeneratePromptRequest(BaseModel):
"""生成提示词请求"""
provider_id: str = Field(..., description="使用的AI接口ID")
name: str = Field(..., description="Agent名称")
role: str = Field(..., description="角色定位")
description: Optional[str] = Field(None, description="额外描述(可选)")
class GeneratePromptResponse(BaseModel):
"""生成提示词响应"""
success: bool
message: Optional[str] = None
prompt: Optional[str] = None
model: Optional[str] = None
tokens: Optional[int] = None
# ============ 路由处理 ============
@router.post("", response_model=AgentResponse, status_code=status.HTTP_201_CREATED)
async def create_agent(request: AgentCreateRequest):
"""
创建新的Agent
"""
try:
agent = await AgentService.create_agent(
name=request.name,
role=request.role,
system_prompt=request.system_prompt,
provider_id=request.provider_id,
temperature=request.temperature,
max_tokens=request.max_tokens,
capabilities=request.capabilities.dict() if request.capabilities else None,
behavior=request.behavior.dict() if request.behavior else None,
avatar=request.avatar,
color=request.color
)
return _to_response(agent)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"创建Agent失败: {e}")
raise HTTPException(status_code=500, detail="创建失败")
@router.get("", response_model=List[AgentResponse])
async def list_agents(enabled_only: bool = False):
"""
获取所有Agent
"""
agents = await AgentService.get_all_agents(enabled_only)
return [_to_response(a) for a in agents]
@router.get("/templates", response_model=List[TemplateResponse])
async def list_templates():
"""
获取Agent预设模板
"""
return [
TemplateResponse(
template_id=tid,
name=t["name"],
role=t["role"],
system_prompt=t["system_prompt"],
color=t["color"]
)
for tid, t in AGENT_TEMPLATES.items()
]
@router.post("/generate-prompt", response_model=GeneratePromptResponse)
async def generate_prompt(request: GeneratePromptRequest):
"""
使用AI生成Agent系统提示词
"""
result = await AgentService.generate_system_prompt(
provider_id=request.provider_id,
name=request.name,
role=request.role,
description=request.description
)
return GeneratePromptResponse(**result)
@router.get("/{agent_id}", response_model=AgentResponse)
async def get_agent(agent_id: str):
"""
获取指定Agent
"""
agent = await AgentService.get_agent(agent_id)
if not agent:
raise HTTPException(status_code=404, detail="Agent不存在")
return _to_response(agent)
@router.put("/{agent_id}", response_model=AgentResponse)
async def update_agent(agent_id: str, request: AgentUpdateRequest):
"""
更新Agent配置
"""
update_data = request.dict(exclude_unset=True)
# 转换嵌套模型
if "capabilities" in update_data and update_data["capabilities"]:
if hasattr(update_data["capabilities"], "dict"):
update_data["capabilities"] = update_data["capabilities"].dict()
if "behavior" in update_data and update_data["behavior"]:
if hasattr(update_data["behavior"], "dict"):
update_data["behavior"] = update_data["behavior"].dict()
try:
agent = await AgentService.update_agent(agent_id, **update_data)
if not agent:
raise HTTPException(status_code=404, detail="Agent不存在")
return _to_response(agent)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/{agent_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_agent(agent_id: str):
"""
删除Agent
"""
success = await AgentService.delete_agent(agent_id)
if not success:
raise HTTPException(status_code=404, detail="Agent不存在")
@router.post("/{agent_id}/test", response_model=AgentTestResponse)
async def test_agent(agent_id: str, request: Optional[AgentTestRequest] = None):
"""
测试Agent对话
"""
message = request.message if request else "你好,请简单介绍一下你自己。"
result = await AgentService.test_agent(agent_id, message)
return AgentTestResponse(**result)
@router.post("/{agent_id}/duplicate", response_model=AgentResponse)
async def duplicate_agent(agent_id: str, new_name: Optional[str] = None):
"""
复制Agent
"""
agent = await AgentService.duplicate_agent(agent_id, new_name)
if not agent:
raise HTTPException(status_code=404, detail="源Agent不存在")
return _to_response(agent)
@router.post("/from-template/{template_id}", response_model=AgentResponse)
async def create_from_template(template_id: str, provider_id: str):
"""
从模板创建Agent
"""
if template_id not in AGENT_TEMPLATES:
raise HTTPException(status_code=404, detail="模板不存在")
template = AGENT_TEMPLATES[template_id]
try:
agent = await AgentService.create_agent(
name=template["name"],
role=template["role"],
system_prompt=template["system_prompt"],
provider_id=provider_id,
color=template["color"]
)
return _to_response(agent)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# ============ 辅助函数 ============
def _to_response(agent) -> AgentResponse:
"""
转换为响应模型
"""
return AgentResponse(
agent_id=agent.agent_id,
name=agent.name,
role=agent.role,
system_prompt=agent.system_prompt,
provider_id=agent.provider_id,
temperature=agent.temperature,
max_tokens=agent.max_tokens,
capabilities=agent.capabilities,
behavior=agent.behavior,
avatar=agent.avatar,
color=agent.color,
enabled=agent.enabled,
created_at=agent.created_at.isoformat(),
updated_at=agent.updated_at.isoformat()
)
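下面用 httpx 演示创建并测试一个 Agent 的典型调用流程(假设路由挂载在 /api/agents 前缀下,provider_id 为虚构值):

import asyncio
import httpx

async def demo() -> None:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        # 创建Agent
        resp = await client.post("/api/agents", json={
            "name": "产品经理",
            "role": "产品规划和需求分析专家",
            "system_prompt": "你是一位经验丰富的产品经理...",
            "provider_id": "openrouter-abc123",
        })
        resp.raise_for_status()
        agent = resp.json()
        # 测试对话
        test = await client.post(
            f"/api/agents/{agent['agent_id']}/test",
            json={"message": "你好"},
        )
        print(test.json())

asyncio.run(demo())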

387
backend/routers/chatrooms.py Normal file
View File

@@ -0,0 +1,387 @@
"""
聊天室管理路由
"""
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect, status
from pydantic import BaseModel, Field
from loguru import logger
from services.chatroom_service import ChatRoomService
from services.discussion_engine import DiscussionEngine
from services.message_router import MessageRouter
router = APIRouter()
# ============ 请求/响应模型 ============
class ChatRoomConfigModel(BaseModel):
"""聊天室配置模型"""
max_rounds: int = 50
message_history_size: int = 20
consensus_threshold: float = 0.8
round_interval: float = 1.0
allow_user_interrupt: bool = True
class ChatRoomCreateRequest(BaseModel):
"""创建聊天室请求"""
name: str = Field(..., description="聊天室名称")
description: str = Field(default="", description="描述")
agents: List[str] = Field(default_factory=list, description="Agent ID列表")
moderator_agent_id: Optional[str] = Field(default=None, description="主持人Agent ID")
config: Optional[ChatRoomConfigModel] = None
class Config:
json_schema_extra = {
"example": {
"name": "产品设计讨论室",
"description": "用于讨论新产品功能设计",
"agents": ["agent-abc123", "agent-def456"],
"moderator_agent_id": "agent-xyz789"
}
}
class ChatRoomUpdateRequest(BaseModel):
"""更新聊天室请求"""
name: Optional[str] = None
description: Optional[str] = None
agents: Optional[List[str]] = None
moderator_agent_id: Optional[str] = None
config: Optional[ChatRoomConfigModel] = None
class ChatRoomResponse(BaseModel):
"""聊天室响应"""
room_id: str
name: str
description: str
objective: str
agents: List[str]
moderator_agent_id: Optional[str]
config: Dict[str, Any]
status: str
current_round: int
current_discussion_id: Optional[str]
created_at: str
updated_at: str
completed_at: Optional[str]
class MessageResponse(BaseModel):
"""消息响应"""
message_id: str
room_id: str
discussion_id: str
agent_id: Optional[str]
content: str
message_type: str
round: int
created_at: str
class StartDiscussionRequest(BaseModel):
"""启动讨论请求"""
objective: str = Field(..., description="讨论目标")
class DiscussionStatusResponse(BaseModel):
"""讨论状态响应"""
is_active: bool
room_id: str
discussion_id: Optional[str] = None
current_round: int = 0
status: str
# ============ 路由处理 ============
@router.post("", response_model=ChatRoomResponse, status_code=status.HTTP_201_CREATED)
async def create_chatroom(request: ChatRoomCreateRequest):
"""
创建新的聊天室
"""
try:
chatroom = await ChatRoomService.create_chatroom(
name=request.name,
description=request.description,
agents=request.agents,
moderator_agent_id=request.moderator_agent_id,
config=request.config.dict() if request.config else None
)
return _to_response(chatroom)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"创建聊天室失败: {e}")
raise HTTPException(status_code=500, detail="创建失败")
@router.get("", response_model=List[ChatRoomResponse])
async def list_chatrooms():
"""
获取所有聊天室
"""
chatrooms = await ChatRoomService.get_all_chatrooms()
return [_to_response(c) for c in chatrooms]
@router.get("/{room_id}", response_model=ChatRoomResponse)
async def get_chatroom(room_id: str):
"""
获取指定聊天室
"""
chatroom = await ChatRoomService.get_chatroom(room_id)
if not chatroom:
raise HTTPException(status_code=404, detail="聊天室不存在")
return _to_response(chatroom)
@router.put("/{room_id}", response_model=ChatRoomResponse)
async def update_chatroom(room_id: str, request: ChatRoomUpdateRequest):
"""
更新聊天室配置
"""
update_data = request.dict(exclude_unset=True)
if "config" in update_data and update_data["config"]:
if hasattr(update_data["config"], "dict"):
update_data["config"] = update_data["config"].dict()
try:
chatroom = await ChatRoomService.update_chatroom(room_id, **update_data)
if not chatroom:
raise HTTPException(status_code=404, detail="聊天室不存在")
return _to_response(chatroom)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/{room_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_chatroom(room_id: str):
"""
删除聊天室
"""
success = await ChatRoomService.delete_chatroom(room_id)
if not success:
raise HTTPException(status_code=404, detail="聊天室不存在")
@router.post("/{room_id}/agents/{agent_id}", response_model=ChatRoomResponse)
async def add_agent_to_chatroom(room_id: str, agent_id: str):
"""
向聊天室添加Agent
"""
try:
chatroom = await ChatRoomService.add_agent(room_id, agent_id)
if not chatroom:
raise HTTPException(status_code=404, detail="聊天室不存在")
return _to_response(chatroom)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/{room_id}/agents/{agent_id}", response_model=ChatRoomResponse)
async def remove_agent_from_chatroom(room_id: str, agent_id: str):
"""
从聊天室移除Agent
"""
chatroom = await ChatRoomService.remove_agent(room_id, agent_id)
if not chatroom:
raise HTTPException(status_code=404, detail="聊天室不存在")
return _to_response(chatroom)
@router.get("/{room_id}/messages", response_model=List[MessageResponse])
async def get_chatroom_messages(
room_id: str,
limit: int = 50,
skip: int = 0,
discussion_id: Optional[str] = None
):
"""
获取聊天室消息历史
"""
messages = await ChatRoomService.get_messages(
room_id, limit, skip, discussion_id
)
return [_message_to_response(m) for m in messages]
@router.post("/{room_id}/start", response_model=DiscussionStatusResponse)
async def start_discussion(room_id: str, request: StartDiscussionRequest):
"""
启动讨论
"""
try:
# 异步启动讨论(不等待完成)
import asyncio
asyncio.create_task(
DiscussionEngine.start_discussion(room_id, request.objective)
)
# 等待一小段时间让讨论初始化
await asyncio.sleep(0.5)
chatroom = await ChatRoomService.get_chatroom(room_id)
return DiscussionStatusResponse(
is_active=True,
room_id=room_id,
discussion_id=chatroom.current_discussion_id if chatroom else None,
current_round=chatroom.current_round if chatroom else 0,
status=chatroom.status if chatroom else "unknown"
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.post("/{room_id}/pause", response_model=DiscussionStatusResponse)
async def pause_discussion(room_id: str):
"""
暂停讨论
"""
success = await DiscussionEngine.pause_discussion(room_id)
if not success:
raise HTTPException(status_code=400, detail="没有进行中的讨论")
chatroom = await ChatRoomService.get_chatroom(room_id)
return DiscussionStatusResponse(
is_active=False,
room_id=room_id,
discussion_id=chatroom.current_discussion_id if chatroom else None,
current_round=chatroom.current_round if chatroom else 0,
status="paused"
)
@router.post("/{room_id}/resume", response_model=DiscussionStatusResponse)
async def resume_discussion(room_id: str):
"""
恢复讨论
"""
success = await DiscussionEngine.resume_discussion(room_id)
if not success:
raise HTTPException(status_code=400, detail="聊天室不在暂停状态")
chatroom = await ChatRoomService.get_chatroom(room_id)
return DiscussionStatusResponse(
is_active=True,
room_id=room_id,
discussion_id=chatroom.current_discussion_id if chatroom else None,
current_round=chatroom.current_round if chatroom else 0,
status="active"
)
@router.post("/{room_id}/stop", response_model=DiscussionStatusResponse)
async def stop_discussion(room_id: str):
"""
停止讨论
"""
success = await DiscussionEngine.stop_discussion(room_id)
chatroom = await ChatRoomService.get_chatroom(room_id)
return DiscussionStatusResponse(
is_active=False,
room_id=room_id,
discussion_id=chatroom.current_discussion_id if chatroom else None,
current_round=chatroom.current_round if chatroom else 0,
status="stopping" if success else chatroom.status if chatroom else "unknown"
)
@router.get("/{room_id}/status", response_model=DiscussionStatusResponse)
async def get_discussion_status(room_id: str):
"""
获取讨论状态
"""
chatroom = await ChatRoomService.get_chatroom(room_id)
if not chatroom:
raise HTTPException(status_code=404, detail="聊天室不存在")
is_active = DiscussionEngine.is_discussion_active(room_id)
return DiscussionStatusResponse(
is_active=is_active,
room_id=room_id,
discussion_id=chatroom.current_discussion_id,
current_round=chatroom.current_round,
status=chatroom.status
)
# ============ WebSocket端点 ============
@router.websocket("/ws/{room_id}")
async def chatroom_websocket(websocket: WebSocket, room_id: str):
"""
聊天室WebSocket连接
"""
# 验证聊天室存在
chatroom = await ChatRoomService.get_chatroom(room_id)
if not chatroom:
await websocket.close(code=4004, reason="聊天室不存在")
return
await MessageRouter.connect(room_id, websocket)
try:
while True:
# 保持连接,接收客户端消息(如心跳)
data = await websocket.receive_text()
# 处理心跳
if data == "ping":
await websocket.send_text("pong")
except WebSocketDisconnect:
await MessageRouter.disconnect(room_id, websocket)
except Exception as e:
logger.error(f"WebSocket错误: {e}")
await MessageRouter.disconnect(room_id, websocket)
# ============ 辅助函数 ============
def _to_response(chatroom) -> ChatRoomResponse:
"""
转换为响应模型
"""
return ChatRoomResponse(
room_id=chatroom.room_id,
name=chatroom.name,
description=chatroom.description,
objective=chatroom.objective,
agents=chatroom.agents,
moderator_agent_id=chatroom.moderator_agent_id,
config=chatroom.config,
status=chatroom.status,
current_round=chatroom.current_round,
current_discussion_id=chatroom.current_discussion_id,
created_at=chatroom.created_at.isoformat(),
updated_at=chatroom.updated_at.isoformat(),
completed_at=chatroom.completed_at.isoformat() if chatroom.completed_at else None
)
def _message_to_response(message) -> MessageResponse:
"""
转换消息为响应模型
"""
return MessageResponse(
message_id=message.message_id,
room_id=message.room_id,
discussion_id=message.discussion_id,
agent_id=message.agent_id,
content=message.content,
message_type=message.message_type,
round=message.round,
created_at=message.created_at.isoformat()
)
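WebSocket 端点的客户端侧用法可参考下面的草图(假设路由挂载在 /api/chatrooms 前缀下;使用第三方 websockets 库):

import asyncio
import websockets  # pip install websockets

async def listen(room_id: str) -> None:
    uri = f"ws://localhost:8000/api/chatrooms/ws/{room_id}"
    async with websockets.connect(uri) as ws:
        await ws.send("ping")      # 心跳
        print(await ws.recv())     # 预期收到 "pong"
        while True:                # 持续接收讨论过程中的消息推送
            print(await ws.recv())

asyncio.run(listen("room-demo"))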

136
backend/routers/discussions.py Normal file
View File

@@ -0,0 +1,136 @@
"""
讨论结果路由
"""
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from models.discussion_result import DiscussionResult
router = APIRouter()
# ============ 响应模型 ============
class DiscussionResponse(BaseModel):
"""讨论结果响应"""
discussion_id: str
room_id: str
objective: str
consensus_reached: bool
confidence: float
summary: str
action_items: List[str]
unresolved_issues: List[str]
key_decisions: List[str]
total_rounds: int
total_messages: int
participating_agents: List[str]
agent_contributions: Dict[str, int]
status: str
end_reason: str
created_at: str
completed_at: Optional[str]
class DiscussionListResponse(BaseModel):
"""讨论列表响应"""
discussions: List[DiscussionResponse]
total: int
# ============ 路由处理 ============
@router.get("", response_model=DiscussionListResponse)
async def list_discussions(
room_id: Optional[str] = None,
limit: int = 20,
skip: int = 0
):
"""
获取讨论结果列表
"""
query = {}
if room_id:
query["room_id"] = room_id
discussions = await DiscussionResult.find(query).sort(
"-created_at"
).skip(skip).limit(limit).to_list()
total = await DiscussionResult.find(query).count()
return DiscussionListResponse(
discussions=[_to_response(d) for d in discussions],
total=total
)
@router.get("/{discussion_id}", response_model=DiscussionResponse)
async def get_discussion(discussion_id: str):
"""
获取指定讨论结果
"""
discussion = await DiscussionResult.find_one(
DiscussionResult.discussion_id == discussion_id
)
if not discussion:
raise HTTPException(status_code=404, detail="讨论记录不存在")
return _to_response(discussion)
@router.get("/room/{room_id}", response_model=List[DiscussionResponse])
async def get_room_discussions(room_id: str, limit: int = 10):
"""
获取聊天室的讨论历史
"""
discussions = await DiscussionResult.find(
{"room_id": room_id}
).sort("-created_at").limit(limit).to_list()
return [_to_response(d) for d in discussions]
@router.get("/room/{room_id}/latest", response_model=DiscussionResponse)
async def get_latest_discussion(room_id: str):
"""
获取聊天室最新的讨论结果
"""
discussion = await DiscussionResult.find(
{"room_id": room_id}
).sort("-created_at").first_or_none()
if not discussion:
raise HTTPException(status_code=404, detail="没有找到讨论记录")
return _to_response(discussion)
# ============ 辅助函数 ============
def _to_response(discussion: DiscussionResult) -> DiscussionResponse:
"""
转换为响应模型
"""
return DiscussionResponse(
discussion_id=discussion.discussion_id,
room_id=discussion.room_id,
objective=discussion.objective,
consensus_reached=discussion.consensus_reached,
confidence=discussion.confidence,
summary=discussion.summary,
action_items=discussion.action_items,
unresolved_issues=discussion.unresolved_issues,
key_decisions=discussion.key_decisions,
total_rounds=discussion.total_rounds,
total_messages=discussion.total_messages,
participating_agents=discussion.participating_agents,
agent_contributions=discussion.agent_contributions,
status=discussion.status,
end_reason=discussion.end_reason,
created_at=discussion.created_at.isoformat(),
completed_at=discussion.completed_at.isoformat() if discussion.completed_at else None
)
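列表接口支持按 room_id 过滤以及 skip/limit 分页,调用示意如下(路由前缀为假设值):

import asyncio
import httpx

async def fetch_discussions(room_id: str) -> None:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.get(
            "/api/discussions",
            params={"room_id": room_id, "limit": 5, "skip": 0},
        )
        data = resp.json()
        print(f"共 {data['total']} 条讨论")
        for d in data["discussions"]:
            print(d["discussion_id"], d["consensus_reached"], d["summary"][:40])

asyncio.run(fetch_discussions("room-demo"))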

241
backend/routers/providers.py Normal file
View File

@@ -0,0 +1,241 @@
"""
AI接口管理路由
"""
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, HTTPException, status
from pydantic import BaseModel, Field
from loguru import logger
from services.ai_provider_service import AIProviderService
from utils.encryption import mask_api_key
router = APIRouter()
# ============ 请求/响应模型 ============
class ProxyConfigModel(BaseModel):
"""代理配置模型"""
http_proxy: Optional[str] = None
https_proxy: Optional[str] = None
no_proxy: List[str] = []
class RateLimitModel(BaseModel):
"""速率限制模型"""
requests_per_minute: int = 60
tokens_per_minute: int = 100000
class ProviderCreateRequest(BaseModel):
"""创建AI接口请求"""
provider_type: str = Field(..., description="提供商类型: minimax, zhipu, openrouter, kimi, deepseek, gemini, ollama, llmstudio")
name: str = Field(..., description="自定义名称")
model: str = Field(..., description="模型名称")
api_key: str = Field(default="", description="API密钥")
base_url: str = Field(default="", description="API基础URL")
use_proxy: bool = Field(default=False, description="是否使用代理")
proxy_config: Optional[ProxyConfigModel] = None
rate_limit: Optional[RateLimitModel] = None
timeout: int = Field(default=60, description="超时时间(秒)")
extra_params: Dict[str, Any] = Field(default_factory=dict, description="额外参数")
class Config:
json_schema_extra = {
"example": {
"provider_type": "openrouter",
"name": "OpenRouter GPT-4",
"model": "openai/gpt-4-turbo",
"api_key": "sk-xxx",
"use_proxy": True,
"proxy_config": {
"http_proxy": "http://127.0.0.1:7890",
"https_proxy": "http://127.0.0.1:7890"
}
}
}
class ProviderUpdateRequest(BaseModel):
"""更新AI接口请求"""
name: Optional[str] = None
model: Optional[str] = None
api_key: Optional[str] = None
base_url: Optional[str] = None
use_proxy: Optional[bool] = None
proxy_config: Optional[ProxyConfigModel] = None
rate_limit: Optional[RateLimitModel] = None
timeout: Optional[int] = None
extra_params: Optional[Dict[str, Any]] = None
enabled: Optional[bool] = None
class ProviderResponse(BaseModel):
"""AI接口响应"""
provider_id: str
provider_type: str
name: str
api_key_masked: str
base_url: str
model: str
use_proxy: bool
proxy_config: Dict[str, Any]
rate_limit: Dict[str, int]
timeout: int
extra_params: Dict[str, Any]
enabled: bool
created_at: str
updated_at: str
class TestConfigRequest(BaseModel):
"""测试配置请求"""
provider_type: str
api_key: str
base_url: str = ""
model: str = ""
use_proxy: bool = False
proxy_config: Optional[ProxyConfigModel] = None
timeout: int = 30
class TestResponse(BaseModel):
"""测试响应"""
success: bool
message: str
model: Optional[str] = None
latency_ms: Optional[float] = None
# ============ 路由处理 ============
@router.post("", response_model=ProviderResponse, status_code=status.HTTP_201_CREATED)
async def create_provider(request: ProviderCreateRequest):
"""
创建新的AI接口配置
"""
try:
provider = await AIProviderService.create_provider(
provider_type=request.provider_type,
name=request.name,
model=request.model,
api_key=request.api_key,
base_url=request.base_url,
use_proxy=request.use_proxy,
proxy_config=request.proxy_config.dict() if request.proxy_config else None,
rate_limit=request.rate_limit.dict() if request.rate_limit else None,
timeout=request.timeout,
extra_params=request.extra_params
)
return _to_response(provider)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"创建AI接口失败: {e}")
raise HTTPException(status_code=500, detail="创建失败")
@router.get("", response_model=List[ProviderResponse])
async def list_providers(enabled_only: bool = False):
"""
获取所有AI接口配置
"""
providers = await AIProviderService.get_all_providers(enabled_only)
return [_to_response(p) for p in providers]
@router.get("/{provider_id}", response_model=ProviderResponse)
async def get_provider(provider_id: str):
"""
获取指定AI接口配置
"""
provider = await AIProviderService.get_provider(provider_id)
if not provider:
raise HTTPException(status_code=404, detail="AI接口不存在")
return _to_response(provider)
@router.put("/{provider_id}", response_model=ProviderResponse)
async def update_provider(provider_id: str, request: ProviderUpdateRequest):
"""
更新AI接口配置
"""
update_data = request.dict(exclude_unset=True)
# 转换嵌套模型
if "proxy_config" in update_data and update_data["proxy_config"]:
update_data["proxy_config"] = update_data["proxy_config"].dict() if hasattr(update_data["proxy_config"], "dict") else update_data["proxy_config"]
if "rate_limit" in update_data and update_data["rate_limit"]:
update_data["rate_limit"] = update_data["rate_limit"].dict() if hasattr(update_data["rate_limit"], "dict") else update_data["rate_limit"]
try:
provider = await AIProviderService.update_provider(provider_id, **update_data)
if not provider:
raise HTTPException(status_code=404, detail="AI接口不存在")
return _to_response(provider)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/{provider_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_provider(provider_id: str):
"""
删除AI接口配置
"""
success = await AIProviderService.delete_provider(provider_id)
if not success:
raise HTTPException(status_code=404, detail="AI接口不存在")
@router.post("/{provider_id}/test", response_model=TestResponse)
async def test_provider(provider_id: str):
"""
测试AI接口连接
"""
result = await AIProviderService.test_provider(provider_id)
return TestResponse(**result)
@router.post("/test", response_model=TestResponse)
async def test_provider_config(request: TestConfigRequest):
"""
测试AI接口配置(不保存)
"""
result = await AIProviderService.test_provider_config(
provider_type=request.provider_type,
api_key=request.api_key,
base_url=request.base_url,
model=request.model,
use_proxy=request.use_proxy,
proxy_config=request.proxy_config.dict() if request.proxy_config else None,
timeout=request.timeout
)
return TestResponse(**result)
# ============ 辅助函数 ============
def _to_response(provider) -> ProviderResponse:
"""
转换为响应模型
"""
return ProviderResponse(
provider_id=provider.provider_id,
provider_type=provider.provider_type,
name=provider.name,
api_key_masked=mask_api_key(provider.api_key) if provider.api_key else "",
base_url=provider.base_url,
model=provider.model,
use_proxy=provider.use_proxy,
proxy_config=provider.proxy_config,
rate_limit=provider.rate_limit,
timeout=provider.timeout,
extra_params=provider.extra_params,
enabled=provider.enabled,
created_at=provider.created_at.isoformat(),
updated_at=provider.updated_at.isoformat()
)
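保存配置之前,可先通过 /test 端点验证连通性。下面以本地 Ollama 为例(路由前缀与服务地址均为假设值):

import asyncio
import httpx

async def test_config() -> None:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.post("/api/providers/test", json={
            "provider_type": "ollama",
            "api_key": "",                       # 本地Ollama通常无需密钥
            "base_url": "http://localhost:11434",
            "model": "llama3",
        })
        print(resp.json())  # 形如 {"success": ..., "message": ..., "latency_ms": ...}

asyncio.run(test_config())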

22
backend/services/__init__.py Normal file
View File

@@ -0,0 +1,22 @@
"""
业务服务模块
"""
from .ai_provider_service import AIProviderService
from .agent_service import AgentService
from .chatroom_service import ChatRoomService
from .message_router import MessageRouter
from .discussion_engine import DiscussionEngine
from .consensus_manager import ConsensusManager
from .mcp_service import MCPService
from .memory_service import MemoryService
__all__ = [
"AIProviderService",
"AgentService",
"ChatRoomService",
"MessageRouter",
"DiscussionEngine",
"ConsensusManager",
"MCPService",
"MemoryService",
]

438
backend/services/agent_service.py Normal file
View File

@@ -0,0 +1,438 @@
"""
Agent服务
管理AI代理的配置
"""
import uuid
from datetime import datetime
from typing import List, Dict, Any, Optional
from loguru import logger
from models.agent import Agent
from services.ai_provider_service import AIProviderService
class AgentService:
"""
Agent服务类
负责Agent的CRUD操作
"""
@classmethod
async def create_agent(
cls,
name: str,
role: str,
system_prompt: str,
provider_id: str,
temperature: float = 0.7,
max_tokens: int = 2000,
capabilities: Optional[Dict[str, Any]] = None,
behavior: Optional[Dict[str, Any]] = None,
avatar: Optional[str] = None,
color: str = "#1890ff"
) -> Agent:
"""
创建新的Agent
Args:
name: Agent名称
role: 角色定义
system_prompt: 系统提示词
provider_id: 使用的AI接口ID
temperature: 温度参数
max_tokens: 最大token数
capabilities: 能力配置
behavior: 行为配置
avatar: 头像URL
color: 代表颜色
Returns:
创建的Agent文档
"""
# 验证AI接口存在
provider = await AIProviderService.get_provider(provider_id)
if not provider:
raise ValueError(f"AI接口不存在: {provider_id}")
# 生成唯一ID
agent_id = f"agent-{uuid.uuid4().hex[:8]}"
# 默认能力配置
default_capabilities = {
"memory_enabled": False,
"mcp_tools": [],
"skills": [],
"multimodal": False
}
if capabilities:
default_capabilities.update(capabilities)
# 默认行为配置
default_behavior = {
"speak_threshold": 0.5,
"max_speak_per_round": 2,
"speak_style": "balanced"
}
if behavior:
default_behavior.update(behavior)
# 创建文档
agent = Agent(
agent_id=agent_id,
name=name,
role=role,
system_prompt=system_prompt,
provider_id=provider_id,
temperature=temperature,
max_tokens=max_tokens,
capabilities=default_capabilities,
behavior=default_behavior,
avatar=avatar,
color=color,
enabled=True,
created_at=datetime.utcnow(),
updated_at=datetime.utcnow()
)
await agent.insert()
logger.info(f"创建Agent: {agent_id} ({name})")
return agent
@classmethod
async def get_agent(cls, agent_id: str) -> Optional[Agent]:
"""
获取指定Agent
Args:
agent_id: Agent ID
Returns:
Agent文档或None
"""
return await Agent.find_one(Agent.agent_id == agent_id)
@classmethod
async def get_all_agents(
cls,
enabled_only: bool = False
) -> List[Agent]:
"""
获取所有Agent
Args:
enabled_only: 是否只返回启用的Agent
Returns:
Agent列表
"""
if enabled_only:
return await Agent.find(Agent.enabled == True).to_list()
return await Agent.find_all().to_list()
@classmethod
async def get_agents_by_ids(
cls,
agent_ids: List[str]
) -> List[Agent]:
"""
根据ID列表获取多个Agent
Args:
agent_ids: Agent ID列表
Returns:
Agent列表
"""
return await Agent.find(
{"agent_id": {"$in": agent_ids}}
).to_list()
@classmethod
async def update_agent(
cls,
agent_id: str,
**kwargs
) -> Optional[Agent]:
"""
更新Agent配置
Args:
agent_id: Agent ID
**kwargs: 要更新的字段
Returns:
更新后的Agent或None
"""
agent = await cls.get_agent(agent_id)
if not agent:
return None
# 如果更新了provider_id验证其存在
if "provider_id" in kwargs:
provider = await AIProviderService.get_provider(kwargs["provider_id"])
if not provider:
raise ValueError(f"AI接口不存在: {kwargs['provider_id']}")
# 更新字段
kwargs["updated_at"] = datetime.utcnow()
for key, value in kwargs.items():
if hasattr(agent, key):
setattr(agent, key, value)
await agent.save()
logger.info(f"更新Agent: {agent_id}")
return agent
@classmethod
async def delete_agent(cls, agent_id: str) -> bool:
"""
删除Agent
Args:
agent_id: Agent ID
Returns:
是否删除成功
"""
agent = await cls.get_agent(agent_id)
if not agent:
return False
await agent.delete()
logger.info(f"删除Agent: {agent_id}")
return True
@classmethod
async def test_agent(
cls,
agent_id: str,
test_message: str = "你好,请简单介绍一下你自己。"
) -> Dict[str, Any]:
"""
测试Agent对话
Args:
agent_id: Agent ID
test_message: 测试消息
Returns:
测试结果
"""
agent = await cls.get_agent(agent_id)
if not agent:
return {
"success": False,
"message": f"Agent不存在: {agent_id}"
}
if not agent.enabled:
return {
"success": False,
"message": "Agent已禁用"
}
# 构建消息
messages = [
{"role": "system", "content": agent.system_prompt},
{"role": "user", "content": test_message}
]
# 调用AI接口
response = await AIProviderService.chat(
provider_id=agent.provider_id,
messages=messages,
temperature=agent.temperature,
max_tokens=agent.max_tokens
)
if response.success:
return {
"success": True,
"message": "测试成功",
"response": response.content,
"model": response.model,
"tokens": response.total_tokens,
"latency_ms": response.latency_ms
}
else:
return {
"success": False,
"message": response.error
}
@classmethod
async def duplicate_agent(
cls,
agent_id: str,
new_name: Optional[str] = None
) -> Optional[Agent]:
"""
复制Agent
Args:
agent_id: 源Agent ID
new_name: 新Agent名称
Returns:
新创建的Agent或None
"""
source_agent = await cls.get_agent(agent_id)
if not source_agent:
return None
return await cls.create_agent(
name=new_name or f"{source_agent.name} (副本)",
role=source_agent.role,
system_prompt=source_agent.system_prompt,
provider_id=source_agent.provider_id,
temperature=source_agent.temperature,
max_tokens=source_agent.max_tokens,
capabilities=source_agent.capabilities,
behavior=source_agent.behavior,
avatar=source_agent.avatar,
color=source_agent.color
)
@classmethod
async def generate_system_prompt(
cls,
provider_id: str,
name: str,
role: str,
description: Optional[str] = None
) -> Dict[str, Any]:
"""
使用AI生成Agent系统提示词
Args:
provider_id: AI接口ID
name: Agent名称
role: 角色定位
description: 额外描述(可选)
Returns:
生成结果包含success和生成的prompt
"""
# 验证AI接口存在
provider = await AIProviderService.get_provider(provider_id)
if not provider:
return {
"success": False,
"message": f"AI接口不存在: {provider_id}"
}
# 构建生成提示词的请求
generate_prompt = f"""请为一个AI Agent编写系统提示词system prompt
Agent名称:{name}
角色定位:{role}
{f'补充说明:{description}' if description else ''}
要求:
1. 提示词应简洁专业,控制在200字以内
2. 明确该Agent的核心职责和专业领域
3. 说明在多Agent讨论中应该关注什么
4. 使用中文编写
5. 不要包含任何问候语或开场白,直接给出提示词内容
请直接输出系统提示词,不要有任何额外的解释或包装。"""
try:
messages = [{"role": "user", "content": generate_prompt}]
response = await AIProviderService.chat(
provider_id=provider_id,
messages=messages,
temperature=0.7,
max_tokens=1000
)
if response.success:
# 清理可能的包装文本
content = response.content.strip()
# 移除可能的markdown代码块标记
if content.startswith("```"):
lines = content.split("\n")
content = "\n".join(lines[1:])
if content.endswith("```"):
content = content[:-3]
content = content.strip()
return {
"success": True,
"prompt": content,
"model": response.model,
"tokens": response.total_tokens
}
else:
return {
"success": False,
"message": response.error or "生成失败"
}
except Exception as e:
logger.error(f"生成系统提示词失败: {e}")
return {
"success": False,
"message": f"生成失败: {str(e)}"
}
# Agent预设模板
AGENT_TEMPLATES = {
"product_manager": {
"name": "产品经理",
"role": "产品规划和需求分析专家",
"system_prompt": """你是一位经验丰富的产品经理,擅长:
- 分析用户需求和痛点
- 制定产品策略和路线图
- 平衡业务目标和用户体验
- 与团队协作推进产品迭代
在讨论中,你需要从产品角度出发,关注用户价值、商业可行性和优先级排序。
请用专业但易懂的语言表达观点。""",
"color": "#1890ff"
},
"developer": {
"name": "开发工程师",
"role": "技术实现和架构设计专家",
"system_prompt": """你是一位资深的软件开发工程师,擅长:
- 系统架构设计
- 代码实现和优化
- 技术方案评估
- 性能和安全考量
在讨论中,你需要从技术角度出发,关注实现可行性、技术债务和最佳实践。
请提供具体的技术建议和潜在风险评估。""",
"color": "#52c41a"
},
"designer": {
"name": "设计师",
"role": "用户体验和界面设计专家",
"system_prompt": """你是一位专业的UI/UX设计师擅长:
- 用户体验设计
- 界面视觉设计
- 交互流程优化
- 设计系统构建
在讨论中,你需要从设计角度出发,关注用户体验、视觉美感和交互流畅性。
请提供设计建议并考虑可用性和一致性。""",
"color": "#eb2f96"
},
"moderator": {
"name": "主持人",
"role": "讨论主持和共识判断专家",
"system_prompt": """你是讨论的主持人,负责:
- 引导讨论方向
- 总结各方观点
- 判断是否达成共识
- 提炼行动要点
在讨论中,你需要保持中立,促进有效沟通,并在适当时机总结讨论成果。
当各方观点趋于一致时,请明确指出并总结共识内容。""",
"color": "#722ed1"
}
}
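在服务层直接使用预设模板创建Agent的草图如下(假设 Beanie/MongoDB 已完成初始化,provider_id 为已存在的接口ID,示例值均为虚构):

import asyncio
from services.agent_service import AgentService, AGENT_TEMPLATES

async def main() -> None:
    template = AGENT_TEMPLATES["moderator"]
    agent = await AgentService.create_agent(
        name=template["name"],
        role=template["role"],
        system_prompt=template["system_prompt"],
        provider_id="openrouter-abc123",  # 假设的接口ID
        color=template["color"],
    )
    result = await AgentService.test_agent(agent.agent_id)
    print(result["success"], result.get("latency_ms"))

asyncio.run(main())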

364
backend/services/ai_provider_service.py Normal file
View File

@@ -0,0 +1,364 @@
"""
AI接口提供商服务
管理AI接口的配置和调用
"""
import uuid
from datetime import datetime
from typing import List, Dict, Any, Optional
from loguru import logger
from models.ai_provider import AIProvider
from adapters import get_adapter, BaseAdapter, ChatMessage, AdapterResponse
from utils.encryption import encrypt_api_key, decrypt_api_key
from utils.rate_limiter import rate_limiter
class AIProviderService:
"""
AI接口提供商服务类
负责AI接口的CRUD操作和调用
"""
# 缓存适配器实例
_adapter_cache: Dict[str, BaseAdapter] = {}
@classmethod
async def create_provider(
cls,
provider_type: str,
name: str,
model: str,
api_key: str = "",
base_url: str = "",
use_proxy: bool = False,
proxy_config: Optional[Dict[str, Any]] = None,
rate_limit: Optional[Dict[str, int]] = None,
timeout: int = 60,
extra_params: Optional[Dict[str, Any]] = None
) -> AIProvider:
"""
创建新的AI接口配置
Args:
provider_type: 提供商类型
name: 自定义名称
model: 模型名称
api_key: API密钥
base_url: API基础URL
use_proxy: 是否使用代理
proxy_config: 代理配置
rate_limit: 速率限制配置
timeout: 超时时间
extra_params: 额外参数
Returns:
创建的AIProvider文档
"""
# 验证提供商类型
try:
get_adapter(provider_type)
except ValueError:
raise ValueError(f"不支持的提供商类型: {provider_type}")
# 生成唯一ID
provider_id = f"{provider_type}-{uuid.uuid4().hex[:8]}"
# 加密API密钥
encrypted_key = encrypt_api_key(api_key) if api_key else ""
# 创建文档
provider = AIProvider(
provider_id=provider_id,
provider_type=provider_type,
name=name,
api_key=encrypted_key,
base_url=base_url,
model=model,
use_proxy=use_proxy,
proxy_config=proxy_config or {},
rate_limit=rate_limit or {"requests_per_minute": 60, "tokens_per_minute": 100000},
timeout=timeout,
extra_params=extra_params or {},
enabled=True,
created_at=datetime.utcnow(),
updated_at=datetime.utcnow()
)
await provider.insert()
# 注册速率限制
rate_limiter.register(
provider_id,
provider.rate_limit.get("requests_per_minute", 60),
provider.rate_limit.get("tokens_per_minute", 100000)
)
logger.info(f"创建AI接口配置: {provider_id} ({name})")
return provider
@classmethod
async def get_provider(cls, provider_id: str) -> Optional[AIProvider]:
"""
获取指定AI接口配置
Args:
provider_id: 接口ID
Returns:
AIProvider文档或None
"""
return await AIProvider.find_one(AIProvider.provider_id == provider_id)
@classmethod
async def get_all_providers(
cls,
enabled_only: bool = False
) -> List[AIProvider]:
"""
获取所有AI接口配置
Args:
enabled_only: 是否只返回启用的接口
Returns:
AIProvider列表
"""
if enabled_only:
return await AIProvider.find(AIProvider.enabled == True).to_list()
return await AIProvider.find_all().to_list()
@classmethod
async def update_provider(
cls,
provider_id: str,
**kwargs
) -> Optional[AIProvider]:
"""
更新AI接口配置
Args:
provider_id: 接口ID
**kwargs: 要更新的字段
Returns:
更新后的AIProvider或None
"""
provider = await cls.get_provider(provider_id)
if not provider:
return None
# 如果更新了API密钥需要加密
if "api_key" in kwargs and kwargs["api_key"]:
kwargs["api_key"] = encrypt_api_key(kwargs["api_key"])
# 更新字段
kwargs["updated_at"] = datetime.utcnow()
for key, value in kwargs.items():
if hasattr(provider, key):
setattr(provider, key, value)
await provider.save()
# 清除适配器缓存
cls._adapter_cache.pop(provider_id, None)
# 更新速率限制
if "rate_limit" in kwargs:
rate_limiter.unregister(provider_id)
rate_limiter.register(
provider_id,
provider.rate_limit.get("requests_per_minute", 60),
provider.rate_limit.get("tokens_per_minute", 100000)
)
logger.info(f"更新AI接口配置: {provider_id}")
return provider
@classmethod
async def delete_provider(cls, provider_id: str) -> bool:
"""
删除AI接口配置
Args:
provider_id: 接口ID
Returns:
是否删除成功
"""
provider = await cls.get_provider(provider_id)
if not provider:
return False
await provider.delete()
# 清除缓存和速率限制
cls._adapter_cache.pop(provider_id, None)
rate_limiter.unregister(provider_id)
logger.info(f"删除AI接口配置: {provider_id}")
return True
@classmethod
async def get_adapter(cls, provider_id: str) -> Optional[BaseAdapter]:
"""
获取AI接口的适配器实例
Args:
provider_id: 接口ID
Returns:
适配器实例或None
"""
# 检查缓存
if provider_id in cls._adapter_cache:
return cls._adapter_cache[provider_id]
provider = await cls.get_provider(provider_id)
if not provider or not provider.enabled:
return None
# 解密API密钥
api_key = decrypt_api_key(provider.api_key) if provider.api_key else ""
# 创建适配器
adapter_class = get_adapter(provider.provider_type)
adapter = adapter_class(
api_key=api_key,
base_url=provider.base_url,
model=provider.model,
use_proxy=provider.use_proxy,
proxy_config=provider.proxy_config,
timeout=provider.timeout,
**provider.extra_params
)
# 缓存适配器
cls._adapter_cache[provider_id] = adapter
return adapter
@classmethod
async def chat(
cls,
provider_id: str,
messages: List[Dict[str, str]],
temperature: float = 0.7,
max_tokens: int = 2000,
**kwargs
) -> AdapterResponse:
"""
调用AI接口进行对话
Args:
provider_id: 接口ID
messages: 消息列表 [{"role": "user", "content": "..."}]
temperature: 温度参数
max_tokens: 最大token数
**kwargs: 额外参数
Returns:
适配器响应
"""
adapter = await cls.get_adapter(provider_id)
if not adapter:
return AdapterResponse(
success=False,
error=f"AI接口不存在或未启用: {provider_id}"
)
# 检查速率限制(token数为粗略估算:约4个字符折合1个token)
estimated_tokens = sum(len(m.get("content", "")) for m in messages) // 4
if not await rate_limiter.acquire_wait(provider_id, estimated_tokens):
return AdapterResponse(
success=False,
error="请求频率超限,请稍后重试"
)
# 转换消息格式
chat_messages = [
ChatMessage(
role=m.get("role", "user"),
content=m.get("content", ""),
name=m.get("name")
)
for m in messages
]
# 调用适配器
response = await adapter.chat(
messages=chat_messages,
temperature=temperature,
max_tokens=max_tokens,
**kwargs
)
return response
@classmethod
async def test_provider(cls, provider_id: str) -> Dict[str, Any]:
"""
测试AI接口连接
Args:
provider_id: 接口ID
Returns:
测试结果
"""
adapter = await cls.get_adapter(provider_id)
if not adapter:
return {
"success": False,
"message": f"AI接口不存在或未启用: {provider_id}"
}
return await adapter.test_connection()
@classmethod
async def test_provider_config(
cls,
provider_type: str,
api_key: str,
base_url: str = "",
model: str = "",
use_proxy: bool = False,
proxy_config: Optional[Dict[str, Any]] = None,
timeout: int = 30,
**kwargs
) -> Dict[str, Any]:
"""
测试AI接口配置不保存
Args:
provider_type: 提供商类型
api_key: API密钥
base_url: API基础URL
model: 模型名称
use_proxy: 是否使用代理
proxy_config: 代理配置
timeout: 超时时间
**kwargs: 额外参数
Returns:
测试结果
"""
try:
adapter_class = get_adapter(provider_type)
except ValueError:
return {
"success": False,
"message": f"不支持的提供商类型: {provider_type}"
}
adapter = adapter_class(
api_key=api_key,
base_url=base_url,
model=model,
use_proxy=use_proxy,
proxy_config=proxy_config,
timeout=timeout,
**kwargs
)
return await adapter.test_connection()
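chat 方法封装了适配器获取、密钥解密与速率限制,调用方只需传入消息列表。最小调用草图如下(假设数据库已初始化且接口已创建,provider_id 为虚构值):

import asyncio
from services.ai_provider_service import AIProviderService

async def main() -> None:
    response = await AIProviderService.chat(
        provider_id="openrouter-abc123",
        messages=[
            {"role": "system", "content": "你是一个简洁的助手。"},
            {"role": "user", "content": "用一句话介绍FastAPI。"},
        ],
        temperature=0.3,
        max_tokens=200,
    )
    if response.success:
        print(response.content, response.total_tokens, response.latency_ms)
    else:
        print("调用失败:", response.error)

asyncio.run(main())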

357
backend/services/chatroom_service.py Normal file
View File

@@ -0,0 +1,357 @@
"""
聊天室服务
管理聊天室的创建和状态
"""
import uuid
from datetime import datetime
from typing import List, Dict, Any, Optional
from loguru import logger
from models.chatroom import ChatRoom, ChatRoomStatus
from models.message import Message
from services.agent_service import AgentService
class ChatRoomService:
"""
聊天室服务类
负责聊天室的CRUD操作
"""
@classmethod
async def create_chatroom(
cls,
name: str,
description: str = "",
agents: Optional[List[str]] = None,
moderator_agent_id: Optional[str] = None,
config: Optional[Dict[str, Any]] = None
) -> ChatRoom:
"""
创建新的聊天室
Args:
name: 聊天室名称
description: 描述
agents: Agent ID列表
moderator_agent_id: 主持人Agent ID
config: 聊天室配置
Returns:
创建的ChatRoom文档
"""
# 验证Agent存在
if agents:
existing_agents = await AgentService.get_agents_by_ids(agents)
existing_ids = {a.agent_id for a in existing_agents}
missing_ids = set(agents) - existing_ids
if missing_ids:
raise ValueError(f"Agent不存在: {', '.join(missing_ids)}")
# 验证主持人Agent
if moderator_agent_id:
moderator = await AgentService.get_agent(moderator_agent_id)
if not moderator:
raise ValueError(f"主持人Agent不存在: {moderator_agent_id}")
# 生成唯一ID
room_id = f"room-{uuid.uuid4().hex[:8]}"
# 默认配置
default_config = {
"max_rounds": 50,
"message_history_size": 20,
"consensus_threshold": 0.8,
"round_interval": 1.0,
"allow_user_interrupt": True
}
if config:
default_config.update(config)
# 创建文档
chatroom = ChatRoom(
room_id=room_id,
name=name,
description=description,
objective="",
agents=agents or [],
moderator_agent_id=moderator_agent_id,
config=default_config,
status=ChatRoomStatus.IDLE.value,
current_round=0,
created_at=datetime.utcnow(),
updated_at=datetime.utcnow()
)
await chatroom.insert()
logger.info(f"创建聊天室: {room_id} ({name})")
return chatroom
@classmethod
async def get_chatroom(cls, room_id: str) -> Optional[ChatRoom]:
"""
获取指定聊天室
Args:
room_id: 聊天室ID
Returns:
ChatRoom文档或None
"""
return await ChatRoom.find_one(ChatRoom.room_id == room_id)
@classmethod
async def get_all_chatrooms(cls) -> List[ChatRoom]:
"""
获取所有聊天室
Returns:
ChatRoom列表
"""
return await ChatRoom.find_all().to_list()
@classmethod
async def update_chatroom(
cls,
room_id: str,
**kwargs
) -> Optional[ChatRoom]:
"""
更新聊天室配置
Args:
room_id: 聊天室ID
**kwargs: 要更新的字段
Returns:
更新后的ChatRoom或None
"""
chatroom = await cls.get_chatroom(room_id)
if not chatroom:
return None
# 验证Agent
if "agents" in kwargs:
existing_agents = await AgentService.get_agents_by_ids(kwargs["agents"])
existing_ids = {a.agent_id for a in existing_agents}
missing_ids = set(kwargs["agents"]) - existing_ids
if missing_ids:
raise ValueError(f"Agent不存在: {', '.join(missing_ids)}")
# 验证主持人
if "moderator_agent_id" in kwargs and kwargs["moderator_agent_id"]:
moderator = await AgentService.get_agent(kwargs["moderator_agent_id"])
if not moderator:
raise ValueError(f"主持人Agent不存在: {kwargs['moderator_agent_id']}")
# 更新字段
kwargs["updated_at"] = datetime.utcnow()
for key, value in kwargs.items():
if hasattr(chatroom, key):
setattr(chatroom, key, value)
await chatroom.save()
logger.info(f"更新聊天室: {room_id}")
return chatroom
@classmethod
async def delete_chatroom(cls, room_id: str) -> bool:
"""
删除聊天室
Args:
room_id: 聊天室ID
Returns:
是否删除成功
"""
chatroom = await cls.get_chatroom(room_id)
if not chatroom:
return False
# 删除相关消息
await Message.find(Message.room_id == room_id).delete()
await chatroom.delete()
logger.info(f"删除聊天室: {room_id}")
return True
@classmethod
async def add_agent(cls, room_id: str, agent_id: str) -> Optional[ChatRoom]:
"""
向聊天室添加Agent
Args:
room_id: 聊天室ID
agent_id: Agent ID
Returns:
更新后的ChatRoom或None
"""
chatroom = await cls.get_chatroom(room_id)
if not chatroom:
return None
# 验证Agent存在
agent = await AgentService.get_agent(agent_id)
if not agent:
raise ValueError(f"Agent不存在: {agent_id}")
# 添加Agent
if agent_id not in chatroom.agents:
chatroom.agents.append(agent_id)
chatroom.updated_at = datetime.utcnow()
await chatroom.save()
return chatroom
@classmethod
async def remove_agent(cls, room_id: str, agent_id: str) -> Optional[ChatRoom]:
"""
从聊天室移除Agent
Args:
room_id: 聊天室ID
agent_id: Agent ID
Returns:
更新后的ChatRoom或None
"""
chatroom = await cls.get_chatroom(room_id)
if not chatroom:
return None
# 移除Agent
if agent_id in chatroom.agents:
chatroom.agents.remove(agent_id)
chatroom.updated_at = datetime.utcnow()
await chatroom.save()
return chatroom
@classmethod
async def set_objective(
cls,
room_id: str,
objective: str
) -> Optional[ChatRoom]:
"""
设置讨论目标
Args:
room_id: 聊天室ID
objective: 讨论目标
Returns:
更新后的ChatRoom或None
"""
return await cls.update_chatroom(room_id, objective=objective)
@classmethod
async def update_status(
cls,
room_id: str,
status: ChatRoomStatus
) -> Optional[ChatRoom]:
"""
更新聊天室状态
Args:
room_id: 聊天室ID
status: 新状态
Returns:
更新后的ChatRoom或None
"""
chatroom = await cls.get_chatroom(room_id)
if not chatroom:
return None
chatroom.status = status.value
chatroom.updated_at = datetime.utcnow()
if status == ChatRoomStatus.COMPLETED:
chatroom.completed_at = datetime.utcnow()
await chatroom.save()
logger.info(f"聊天室状态更新: {room_id} -> {status.value}")
return chatroom
@classmethod
async def increment_round(cls, room_id: str) -> Optional[ChatRoom]:
"""
增加轮次计数
Args:
room_id: 聊天室ID
Returns:
更新后的ChatRoom或None
"""
chatroom = await cls.get_chatroom(room_id)
if not chatroom:
return None
chatroom.current_round += 1
chatroom.updated_at = datetime.utcnow()
await chatroom.save()
return chatroom
@classmethod
async def get_messages(
cls,
room_id: str,
limit: int = 50,
skip: int = 0,
discussion_id: Optional[str] = None
) -> List[Message]:
"""
获取聊天室消息历史
Args:
room_id: 聊天室ID
limit: 返回数量限制
skip: 跳过数量
discussion_id: 讨论ID可选
Returns:
消息列表
"""
query = {"room_id": room_id}
if discussion_id:
query["discussion_id"] = discussion_id
return await Message.find(query).sort(
"-created_at"
).skip(skip).limit(limit).to_list()
@classmethod
async def get_recent_messages(
cls,
room_id: str,
count: int = 20,
discussion_id: Optional[str] = None
) -> List[Message]:
"""
获取最近的消息
Args:
room_id: 聊天室ID
count: 消息数量
discussion_id: 讨论ID可选
Returns:
消息列表(按时间正序)
"""
messages = await cls.get_messages(
room_id,
limit=count,
discussion_id=discussion_id
)
return list(reversed(messages)) # 返回正序
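get_messages 按 created_at 倒序返回最新一页,get_recent_messages 在此基础上反转为时间正序,便于直接拼接进提示词。下面是一段用法示意(room_id 为示例值,并非源码的一部分):

# 用法示意:分页拉取历史消息(倒序)与获取最近消息(正序)
from services.chatroom_service import ChatRoomService

async def demo_history(room_id: str) -> None:
    page2 = await ChatRoomService.get_messages(room_id, limit=50, skip=50)  # 第2页,最新在前
    recent = await ChatRoomService.get_recent_messages(room_id, count=20)   # 已反转为时间正序
    print("第2页消息数:", len(page2))
    for msg in recent:
        print(msg.created_at, msg.agent_id, msg.content[:40])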

227
backend/services/consensus_manager.py Normal file
View File

@@ -0,0 +1,227 @@
"""
共识管理器
判断讨论是否达成共识
"""
import json
from typing import Dict, Any, Optional
from loguru import logger
from models.agent import Agent
from models.chatroom import ChatRoom
from services.ai_provider_service import AIProviderService
class ConsensusManager:
"""
共识管理器
使用主持人Agent判断讨论共识
"""
# 共识判断提示词模板
CONSENSUS_PROMPT = """你是讨论的主持人,负责判断讨论是否达成共识。
讨论目标:{objective}
对话历史:
{history}
请仔细分析对话内容,判断:
1. 参与者是否对核心问题达成一致意见?
2. 是否还有重要分歧未解决?
3. 讨论结果是否足够明确和可执行?
请以JSON格式回复不要包含任何其他文字
{{
"consensus_reached": true或false,
"confidence": 0到1之间的数字,
"summary": "讨论结果摘要,简洁概括达成的共识或当前状态",
"action_items": ["具体的行动项列表"],
"unresolved_issues": ["未解决的问题列表"],
"key_decisions": ["关键决策列表"]
}}
注意:
- consensus_reached为true表示核心问题已有明确结论
- confidence表示你对共识判断的信心程度
- 如果讨论仍有争议或不够深入应该返回false
- action_items应该是具体可执行的任务
- 请确保返回有效的JSON格式"""
@classmethod
async def check_consensus(
cls,
moderator: Agent,
context: "DiscussionContext",
chatroom: ChatRoom
) -> Dict[str, Any]:
"""
检查是否达成共识
Args:
moderator: 主持人Agent
context: 讨论上下文
chatroom: 聊天室
Returns:
共识判断结果
"""
from services.discussion_engine import DiscussionContext
# 构建历史记录
history_text = ""
for msg in context.messages:
if msg.agent_id:
history_text += f"[{msg.agent_id}]: {msg.content}\n\n"
else:
history_text += f"[系统]: {msg.content}\n\n"
if not history_text:
return {
"consensus_reached": False,
"confidence": 0,
"summary": "讨论尚未开始",
"action_items": [],
"unresolved_issues": [],
"key_decisions": []
}
# 构建提示词
prompt = cls.CONSENSUS_PROMPT.format(
objective=context.objective,
history=history_text
)
try:
# 调用主持人Agent的AI接口
response = await AIProviderService.chat(
provider_id=moderator.provider_id,
messages=[{"role": "user", "content": prompt}],
temperature=0.3, # 使用较低温度以获得更一致的结果
max_tokens=1000
)
if not response.success:
logger.error(f"共识判断失败: {response.error}")
return cls._default_result("AI接口调用失败")
# 解析JSON响应
content = response.content.strip()
# 尝试提取JSON部分
try:
# 尝试直接解析
result = json.loads(content)
except json.JSONDecodeError:
# 尝试提取JSON块
import re
json_match = re.search(r'\{[\s\S]*\}', content)
if json_match:
try:
result = json.loads(json_match.group())
except json.JSONDecodeError:
logger.warning(f"无法解析共识判断结果: {content}")
return cls._default_result("无法解析AI响应")
else:
return cls._default_result("AI响应格式错误")
# 验证和规范化结果
return cls._normalize_result(result)
except Exception as e:
logger.error(f"共识判断异常: {e}")
return cls._default_result(str(e))
@classmethod
async def generate_summary(
cls,
moderator: Agent,
context: "DiscussionContext"
) -> str:
"""
生成讨论摘要
Args:
moderator: 主持人Agent
context: 讨论上下文
Returns:
讨论摘要
"""
from services.discussion_engine import DiscussionContext
# 构建历史记录
history_text = ""
for msg in context.messages:
if msg.agent_id:
history_text += f"[{msg.agent_id}]: {msg.content}\n\n"
prompt = f"""请为以下讨论生成一份简洁的摘要。
讨论目标:{context.objective}
对话记录:
{history_text}
请提供:
1. 讨论的主要观点和结论
2. 参与者的立场和建议
3. 最终的决策或共识(如果有)
摘要应该简洁明了控制在300字以内。"""
try:
response = await AIProviderService.chat(
provider_id=moderator.provider_id,
messages=[{"role": "user", "content": prompt}],
temperature=0.5,
max_tokens=500
)
if response.success:
return response.content.strip()
else:
return "无法生成摘要"
except Exception as e:
logger.error(f"生成摘要异常: {e}")
return "生成摘要时发生错误"
@classmethod
def _default_result(cls, error: str = "") -> Dict[str, Any]:
"""
返回默认结果
Args:
error: 错误信息
Returns:
默认共识结果
"""
return {
"consensus_reached": False,
"confidence": 0,
"summary": error if error else "共识判断失败",
"action_items": [],
"unresolved_issues": [],
"key_decisions": []
}
@classmethod
def _normalize_result(cls, result: Dict[str, Any]) -> Dict[str, Any]:
"""
规范化共识结果
Args:
result: 原始结果
Returns:
规范化的结果
"""
return {
"consensus_reached": bool(result.get("consensus_reached", False)),
"confidence": max(0, min(1, float(result.get("confidence", 0)))),
"summary": str(result.get("summary", "")),
"action_items": list(result.get("action_items", [])),
"unresolved_issues": list(result.get("unresolved_issues", [])),
"key_decisions": list(result.get("key_decisions", []))
}
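一个最小的调用示意(moderator、context、chatroom 假定为讨论循环中已有的对象,阈值 0.8 仅作演示,并非源码的一部分):

from services.consensus_manager import ConsensusManager

async def demo_consensus(moderator, context, chatroom) -> None:
    verdict = await ConsensusManager.check_consensus(moderator, context, chatroom)
    # _normalize_result 保证字段齐全,且 confidence 被夹在 [0, 1] 区间
    if verdict["consensus_reached"] and verdict["confidence"] >= 0.8:
        print("达成共识:", verdict["summary"])
    else:
        print("仍有分歧:", verdict["unresolved_issues"])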

589
backend/services/discussion_engine.py Normal file
View File

@@ -0,0 +1,589 @@
"""
讨论引擎
实现自由讨论的核心逻辑
"""
import uuid
import asyncio
from datetime import datetime
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
from loguru import logger
from models.chatroom import ChatRoom, ChatRoomStatus
from models.agent import Agent
from models.message import Message, MessageType
from models.discussion_result import DiscussionResult
from services.ai_provider_service import AIProviderService
from services.agent_service import AgentService
from services.chatroom_service import ChatRoomService
from services.message_router import MessageRouter
from services.consensus_manager import ConsensusManager
@dataclass
class DiscussionContext:
"""讨论上下文"""
discussion_id: str
room_id: str
objective: str
current_round: int = 0
messages: List[Message] = field(default_factory=list)
agent_speak_counts: Dict[str, int] = field(default_factory=dict)
def add_message(self, message: Message) -> None:
"""添加消息到上下文"""
self.messages.append(message)
if message.agent_id:
self.agent_speak_counts[message.agent_id] = \
self.agent_speak_counts.get(message.agent_id, 0) + 1
def get_recent_messages(self, count: int = 20) -> List[Message]:
"""获取最近的消息"""
return self.messages[-count:] if len(self.messages) > count else self.messages
def get_agent_speak_count(self, agent_id: str) -> int:
"""获取Agent在当前轮次的发言次数"""
return self.agent_speak_counts.get(agent_id, 0)
def reset_round_counts(self) -> None:
"""重置轮次发言计数"""
self.agent_speak_counts.clear()
class DiscussionEngine:
"""
讨论引擎
实现多Agent自由讨论的核心逻辑
"""
# 活跃的讨论: room_id -> DiscussionContext
_active_discussions: Dict[str, DiscussionContext] = {}
# 停止信号
_stop_signals: Dict[str, bool] = {}
@classmethod
async def start_discussion(
cls,
room_id: str,
objective: str
) -> Optional[DiscussionResult]:
"""
启动讨论
Args:
room_id: 聊天室ID
objective: 讨论目标
Returns:
讨论结果
"""
# 获取聊天室
chatroom = await ChatRoomService.get_chatroom(room_id)
if not chatroom:
raise ValueError(f"聊天室不存在: {room_id}")
if not chatroom.agents:
raise ValueError("聊天室没有Agent参与")
if not objective:
raise ValueError("讨论目标不能为空")
# 检查是否已有活跃讨论
if room_id in cls._active_discussions:
raise ValueError("聊天室已有进行中的讨论")
# 创建讨论
discussion_id = f"disc-{uuid.uuid4().hex[:8]}"
# 创建讨论结果记录
discussion_result = DiscussionResult(
discussion_id=discussion_id,
room_id=room_id,
objective=objective,
status="in_progress",
created_at=datetime.utcnow(),
updated_at=datetime.utcnow()
)
await discussion_result.insert()
# 创建讨论上下文
context = DiscussionContext(
discussion_id=discussion_id,
room_id=room_id,
objective=objective
)
cls._active_discussions[room_id] = context
cls._stop_signals[room_id] = False
# 更新聊天室状态
await ChatRoomService.update_chatroom(
room_id,
status=ChatRoomStatus.ACTIVE.value,
objective=objective,
current_discussion_id=discussion_id,
current_round=0
)
# 广播讨论开始
await MessageRouter.broadcast_status(room_id, "discussion_started", {
"discussion_id": discussion_id,
"objective": objective
})
# 发送系统消息
await MessageRouter.save_and_broadcast_message(
room_id=room_id,
discussion_id=discussion_id,
agent_id=None,
content=f"讨论开始\n\n目标:{objective}",
message_type=MessageType.SYSTEM.value,
round_num=0
)
logger.info(f"讨论开始: {room_id} - {objective}")
# 运行讨论循环
try:
result = await cls._run_discussion_loop(chatroom, context)
return result
except Exception as e:
logger.error(f"讨论异常: {e}")
await cls._handle_discussion_error(room_id, discussion_id, str(e))
raise
finally:
# 清理
cls._active_discussions.pop(room_id, None)
cls._stop_signals.pop(room_id, None)
@classmethod
async def stop_discussion(cls, room_id: str) -> bool:
"""
停止讨论
Args:
room_id: 聊天室ID
Returns:
是否成功
"""
if room_id not in cls._active_discussions:
return False
cls._stop_signals[room_id] = True
logger.info(f"收到停止讨论信号: {room_id}")
return True
@classmethod
async def pause_discussion(cls, room_id: str) -> bool:
"""
暂停讨论
Args:
room_id: 聊天室ID
Returns:
是否成功
"""
if room_id not in cls._active_discussions:
return False
await ChatRoomService.update_status(room_id, ChatRoomStatus.PAUSED)
await MessageRouter.broadcast_status(room_id, "discussion_paused")
logger.info(f"讨论暂停: {room_id}")
return True
@classmethod
async def resume_discussion(cls, room_id: str) -> bool:
"""
恢复讨论
Args:
room_id: 聊天室ID
Returns:
是否成功
"""
chatroom = await ChatRoomService.get_chatroom(room_id)
if not chatroom or chatroom.status != ChatRoomStatus.PAUSED.value:
return False
await ChatRoomService.update_status(room_id, ChatRoomStatus.ACTIVE)
await MessageRouter.broadcast_status(room_id, "discussion_resumed")
logger.info(f"讨论恢复: {room_id}")
return True
@classmethod
async def _run_discussion_loop(
cls,
chatroom: ChatRoom,
context: DiscussionContext
) -> DiscussionResult:
"""
运行讨论循环
Args:
chatroom: 聊天室
context: 讨论上下文
Returns:
讨论结果
"""
room_id = chatroom.room_id
config = chatroom.get_config()
# 获取所有Agent
agents = await AgentService.get_agents_by_ids(chatroom.agents)
agent_map = {a.agent_id: a for a in agents}
# 获取主持人(用于共识判断)
moderator = None
if chatroom.moderator_agent_id:
moderator = await AgentService.get_agent(chatroom.moderator_agent_id)
consecutive_no_speak = 0 # 连续无人发言的轮次
while context.current_round < config.max_rounds:
# 检查停止信号
if cls._stop_signals.get(room_id, False):
break
# 检查暂停状态
current_chatroom = await ChatRoomService.get_chatroom(room_id)
if current_chatroom and current_chatroom.status == ChatRoomStatus.PAUSED.value:
await asyncio.sleep(1)
continue
# 增加轮次
context.current_round += 1
context.reset_round_counts()
# 广播轮次信息
await MessageRouter.broadcast_round_info(
room_id,
context.current_round,
config.max_rounds
)
# 更新聊天室轮次
await ChatRoomService.update_chatroom(
room_id,
current_round=context.current_round
)
# 本轮是否有人发言
round_has_message = False
# 遍历所有Agent判断是否发言
for agent_id in chatroom.agents:
agent = agent_map.get(agent_id)
if not agent or not agent.enabled:
continue
# 检查本轮发言次数限制
behavior = agent.get_behavior()
if context.get_agent_speak_count(agent_id) >= behavior.max_speak_per_round:
continue
# 判断是否发言
should_speak, content = await cls._should_agent_speak(
agent, context, chatroom
)
if should_speak and content:
# 广播输入状态
await MessageRouter.broadcast_typing(room_id, agent_id, True)
# 保存并广播消息
message = await MessageRouter.save_and_broadcast_message(
room_id=room_id,
discussion_id=context.discussion_id,
agent_id=agent_id,
content=content,
message_type=MessageType.TEXT.value,
round_num=context.current_round
)
# 更新上下文
context.add_message(message)
round_has_message = True
# 广播输入结束
await MessageRouter.broadcast_typing(room_id, agent_id, False)
# 轮次间隔
await asyncio.sleep(config.round_interval)
# 检查是否需要共识判断
if round_has_message and moderator:
consecutive_no_speak = 0
# 每隔几轮检查一次共识
if context.current_round % 3 == 0 or context.current_round >= config.max_rounds - 5:
consensus_result = await ConsensusManager.check_consensus(
moderator, context, chatroom
)
if consensus_result.get("consensus_reached", False):
confidence = consensus_result.get("confidence", 0)
if confidence >= config.consensus_threshold:
# 达成共识,结束讨论
return await cls._finalize_discussion(
context,
consensus_result,
"consensus"
)
else:
consecutive_no_speak += 1
# 连续多轮无人发言,检查共识或结束
if consecutive_no_speak >= 3:
if moderator:
consensus_result = await ConsensusManager.check_consensus(
moderator, context, chatroom
)
return await cls._finalize_discussion(
context,
consensus_result,
"no_more_discussion"
)
else:
return await cls._finalize_discussion(
context,
{"consensus_reached": False, "summary": "讨论结束,无明确共识"},
"no_more_discussion"
)
# 达到最大轮次
if moderator:
consensus_result = await ConsensusManager.check_consensus(
moderator, context, chatroom
)
else:
consensus_result = {"consensus_reached": False, "summary": "达到最大轮次限制"}
return await cls._finalize_discussion(
context,
consensus_result,
"max_rounds"
)
@classmethod
async def _should_agent_speak(
cls,
agent: Agent,
context: DiscussionContext,
chatroom: ChatRoom
) -> tuple[bool, str]:
"""
判断Agent是否应该发言
Args:
agent: Agent实例
context: 讨论上下文
chatroom: 聊天室
Returns:
(是否发言, 发言内容)
"""
# 构建判断提示词
recent_messages = context.get_recent_messages(chatroom.get_config().message_history_size)
history_text = ""
for msg in recent_messages:
if msg.agent_id:
history_text += f"[{msg.agent_id}]: {msg.content}\n\n"
else:
history_text += f"[系统]: {msg.content}\n\n"
prompt = f"""你是{agent.name},角色是{agent.role}
{agent.system_prompt}
当前讨论目标:{context.objective}
对话历史:
{history_text if history_text else "(还没有对话)"}
当前是第{context.current_round}轮讨论。
请根据你的角色判断:
1. 你是否有新的观点或建议要分享?
2. 你是否需要回应其他人的观点?
3. 当前讨论是否需要你的专业意见?
如果你认为需要发言,请直接给出你的发言内容。
如果你认为暂时不需要发言(例如等待更多信息、当前轮次已有足够讨论、或者你的观点已经充分表达),请只回复"PASS"
注意:
- 请保持发言简洁有力每次发言控制在200字以内
- 避免重复已经说过的内容
- 如果已经达成共识或接近共识可以选择PASS"""
try:
# 调用AI接口
response = await AIProviderService.chat(
provider_id=agent.provider_id,
messages=[{"role": "user", "content": prompt}],
temperature=agent.temperature,
max_tokens=agent.max_tokens
)
if not response.success:
logger.warning(f"Agent {agent.agent_id} 响应失败: {response.error}")
return False, ""
content = response.content.strip()
            # 判断是否PASS(前缀匹配即可覆盖"PASS"以及其后附加说明的情况)
            if content.upper().startswith("PASS"):
                return False, ""
return True, content
except Exception as e:
logger.error(f"Agent {agent.agent_id} 判断发言异常: {e}")
return False, ""
@classmethod
async def _finalize_discussion(
cls,
context: DiscussionContext,
consensus_result: Dict[str, Any],
end_reason: str
) -> DiscussionResult:
"""
完成讨论,保存结果
Args:
context: 讨论上下文
consensus_result: 共识判断结果
end_reason: 结束原因
Returns:
讨论结果
"""
room_id = context.room_id
# 获取讨论结果记录
discussion_result = await DiscussionResult.find_one(
DiscussionResult.discussion_id == context.discussion_id
)
if discussion_result:
# 更新统计
discussion_result.update_stats(
total_rounds=context.current_round,
total_messages=len(context.messages),
agent_contributions=context.agent_speak_counts
)
# 标记完成
discussion_result.mark_completed(
consensus_reached=consensus_result.get("consensus_reached", False),
confidence=consensus_result.get("confidence", 0),
summary=consensus_result.get("summary", ""),
action_items=consensus_result.get("action_items", []),
unresolved_issues=consensus_result.get("unresolved_issues", []),
end_reason=end_reason
)
await discussion_result.save()
# 更新聊天室状态
await ChatRoomService.update_status(room_id, ChatRoomStatus.COMPLETED)
# 发送系统消息
summary_text = f"""讨论结束
结果:{"达成共识" if consensus_result.get("consensus_reached") else "未达成明确共识"}
置信度:{consensus_result.get("confidence", 0):.0%}
摘要:{consensus_result.get("summary", "")}
行动项:
{chr(10).join("- " + item for item in consensus_result.get("action_items", [])) or ""}
未解决问题:
{chr(10).join("- " + issue for issue in consensus_result.get("unresolved_issues", [])) or ""}
共进行 {context.current_round} 轮讨论,产生 {len(context.messages)} 条消息。"""
await MessageRouter.save_and_broadcast_message(
room_id=room_id,
discussion_id=context.discussion_id,
agent_id=None,
content=summary_text,
message_type=MessageType.SYSTEM.value,
round_num=context.current_round
)
# 广播讨论结束
await MessageRouter.broadcast_status(room_id, "discussion_completed", {
"discussion_id": context.discussion_id,
"consensus_reached": consensus_result.get("consensus_reached", False),
"end_reason": end_reason
})
logger.info(f"讨论结束: {room_id}, 原因: {end_reason}")
return discussion_result
@classmethod
async def _handle_discussion_error(
cls,
room_id: str,
discussion_id: str,
error: str
) -> None:
"""
处理讨论错误
Args:
room_id: 聊天室ID
discussion_id: 讨论ID
error: 错误信息
"""
# 更新聊天室状态
await ChatRoomService.update_status(room_id, ChatRoomStatus.ERROR)
# 更新讨论结果
discussion_result = await DiscussionResult.find_one(
DiscussionResult.discussion_id == discussion_id
)
if discussion_result:
discussion_result.status = "failed"
discussion_result.end_reason = f"error: {error}"
discussion_result.updated_at = datetime.utcnow()
await discussion_result.save()
# 广播错误
await MessageRouter.broadcast_error(room_id, error)
@classmethod
def get_active_discussion(cls, room_id: str) -> Optional[DiscussionContext]:
"""
获取活跃的讨论上下文
Args:
room_id: 聊天室ID
Returns:
讨论上下文或None
"""
return cls._active_discussions.get(room_id)
@classmethod
def is_discussion_active(cls, room_id: str) -> bool:
"""
检查是否有活跃讨论
Args:
room_id: 聊天室ID
Returns:
是否活跃
"""
return room_id in cls._active_discussions
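start_discussion 会一直运行到讨论结束,HTTP 路由中一般以后台任务方式启动。以下为一个示意(路由路径为假设,并非源码的一部分):

import asyncio
from fastapi import APIRouter, HTTPException
from services.discussion_engine import DiscussionEngine

router = APIRouter()

@router.post("/chatrooms/{room_id}/discussion")
async def start_discussion_api(room_id: str, objective: str):
    if DiscussionEngine.is_discussion_active(room_id):
        raise HTTPException(status_code=409, detail="聊天室已有进行中的讨论")
    # 讨论循环可能运行数分钟,放入后台任务,立即返回
    asyncio.create_task(DiscussionEngine.start_discussion(room_id, objective))
    return {"status": "started", "room_id": room_id}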

252
backend/services/mcp_service.py Normal file
View File

@@ -0,0 +1,252 @@
"""
MCP服务
管理MCP工具的集成和调用
"""
import json
import os
from typing import List, Dict, Any, Optional
from pathlib import Path
from loguru import logger
class MCPService:
"""
MCP工具服务
集成MCP服务器提供工具调用能力
"""
# MCP服务器配置目录
MCP_CONFIG_DIR = Path(os.getenv("CURSOR_MCP_DIR", "~/.cursor/mcps")).expanduser()
# 已注册的工具: server_name -> List[tool_info]
_registered_tools: Dict[str, List[Dict[str, Any]]] = {}
# Agent工具映射: agent_id -> List[tool_name]
_agent_tools: Dict[str, List[str]] = {}
@classmethod
async def initialize(cls) -> None:
"""
初始化MCP服务
扫描并注册可用的MCP工具
"""
logger.info("初始化MCP服务...")
if not cls.MCP_CONFIG_DIR.exists():
logger.warning(f"MCP配置目录不存在: {cls.MCP_CONFIG_DIR}")
return
# 扫描MCP服务器目录
for server_dir in cls.MCP_CONFIG_DIR.iterdir():
if server_dir.is_dir():
await cls._scan_server(server_dir)
logger.info(f"MCP服务初始化完成已注册 {len(cls._registered_tools)} 个服务器")
@classmethod
async def _scan_server(cls, server_dir: Path) -> None:
"""
扫描MCP服务器目录
Args:
server_dir: 服务器目录
"""
server_name = server_dir.name
tools_dir = server_dir / "tools"
if not tools_dir.exists():
return
tools = []
for tool_file in tools_dir.glob("*.json"):
try:
with open(tool_file, "r", encoding="utf-8") as f:
tool_info = json.load(f)
tool_info["_file"] = str(tool_file)
tools.append(tool_info)
except Exception as e:
logger.warning(f"加载MCP工具配置失败: {tool_file} - {e}")
if tools:
cls._registered_tools[server_name] = tools
logger.debug(f"注册MCP服务器: {server_name}, 工具数: {len(tools)}")
@classmethod
def list_servers(cls) -> List[str]:
"""
列出所有可用的MCP服务器
Returns:
服务器名称列表
"""
return list(cls._registered_tools.keys())
@classmethod
def list_tools(cls, server: Optional[str] = None) -> List[Dict[str, Any]]:
"""
列出可用的MCP工具
Args:
server: 服务器名称(可选,不指定则返回所有)
Returns:
工具信息列表
"""
if server:
return cls._registered_tools.get(server, [])
# 返回所有工具
all_tools = []
for server_name, tools in cls._registered_tools.items():
for tool in tools:
tool_copy = tool.copy()
tool_copy["server"] = server_name
all_tools.append(tool_copy)
return all_tools
@classmethod
def get_tool(cls, server: str, tool_name: str) -> Optional[Dict[str, Any]]:
"""
获取指定工具的信息
Args:
server: 服务器名称
tool_name: 工具名称
Returns:
工具信息或None
"""
tools = cls._registered_tools.get(server, [])
for tool in tools:
if tool.get("name") == tool_name:
return tool
return None
@classmethod
async def call_tool(
cls,
server: str,
tool_name: str,
arguments: Dict[str, Any]
) -> Dict[str, Any]:
"""
调用MCP工具
Args:
server: 服务器名称
tool_name: 工具名称
arguments: 工具参数
Returns:
调用结果
"""
tool = cls.get_tool(server, tool_name)
if not tool:
return {
"success": False,
"error": f"工具不存在: {server}/{tool_name}"
}
# TODO: 实际的MCP工具调用逻辑
# 这里需要根据MCP协议实现工具调用
# 目前返回模拟结果
logger.info(f"调用MCP工具: {server}/{tool_name}, 参数: {arguments}")
return {
"success": True,
"result": f"MCP工具调用: {tool_name}",
"tool": tool_name,
"server": server,
"arguments": arguments
}
@classmethod
def register_tool_for_agent(
cls,
agent_id: str,
tool_name: str
) -> bool:
"""
为Agent注册可用工具
Args:
agent_id: Agent ID
tool_name: 工具名称(格式: server/tool_name
Returns:
是否注册成功
"""
if agent_id not in cls._agent_tools:
cls._agent_tools[agent_id] = []
if tool_name not in cls._agent_tools[agent_id]:
cls._agent_tools[agent_id].append(tool_name)
return True
return False
@classmethod
def unregister_tool_for_agent(
cls,
agent_id: str,
tool_name: str
) -> bool:
"""
为Agent注销工具
Args:
agent_id: Agent ID
tool_name: 工具名称
Returns:
是否注销成功
"""
if agent_id in cls._agent_tools:
if tool_name in cls._agent_tools[agent_id]:
cls._agent_tools[agent_id].remove(tool_name)
return True
return False
@classmethod
def get_agent_tools(cls, agent_id: str) -> List[str]:
"""
获取Agent可用的工具列表
Args:
agent_id: Agent ID
Returns:
工具名称列表
"""
return cls._agent_tools.get(agent_id, [])
@classmethod
def get_tools_for_prompt(cls, agent_id: str) -> str:
"""
获取用于提示词的工具描述
Args:
agent_id: Agent ID
Returns:
工具描述文本
"""
tool_names = cls.get_agent_tools(agent_id)
if not tool_names:
return ""
descriptions = []
for full_name in tool_names:
parts = full_name.split("/", 1)
if len(parts) == 2:
server, tool_name = parts
tool = cls.get_tool(server, tool_name)
if tool:
desc = tool.get("description", "无描述")
descriptions.append(f"- {tool_name}: {desc}")
if not descriptions:
return ""
return "你可以使用以下工具:\n" + "\n".join(descriptions)

416
backend/services/memory_service.py Normal file
View File

@@ -0,0 +1,416 @@
"""
记忆服务
管理Agent的记忆存储和检索
"""
import uuid
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional
import numpy as np
from loguru import logger
from models.agent_memory import AgentMemory, MemoryType
class MemoryService:
"""
Agent记忆服务
提供记忆的存储、检索和管理功能
"""
# 嵌入模型(延迟加载)
_embedding_model = None
@classmethod
def _get_embedding_model(cls):
"""
获取嵌入模型实例(延迟加载)
"""
if cls._embedding_model is None:
try:
from sentence_transformers import SentenceTransformer
cls._embedding_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
logger.info("嵌入模型加载成功")
except Exception as e:
logger.warning(f"嵌入模型加载失败: {e}")
return None
return cls._embedding_model
@classmethod
async def create_memory(
cls,
agent_id: str,
content: str,
memory_type: str = MemoryType.SHORT_TERM.value,
importance: float = 0.5,
source_room_id: Optional[str] = None,
source_discussion_id: Optional[str] = None,
tags: Optional[List[str]] = None,
expires_in_hours: Optional[int] = None
) -> AgentMemory:
"""
创建新的记忆
Args:
agent_id: Agent ID
content: 记忆内容
memory_type: 记忆类型
importance: 重要性评分
source_room_id: 来源聊天室
source_discussion_id: 来源讨论
tags: 标签
expires_in_hours: 过期时间(小时)
Returns:
创建的AgentMemory文档
"""
memory_id = f"mem-{uuid.uuid4().hex[:12]}"
# 生成向量嵌入
embedding = await cls._generate_embedding(content)
# 生成摘要
summary = content[:100] + "..." if len(content) > 100 else content
# 计算过期时间
expires_at = None
if expires_in_hours:
expires_at = datetime.utcnow() + timedelta(hours=expires_in_hours)
memory = AgentMemory(
memory_id=memory_id,
agent_id=agent_id,
memory_type=memory_type,
content=content,
summary=summary,
embedding=embedding,
importance=importance,
source_room_id=source_room_id,
source_discussion_id=source_discussion_id,
tags=tags or [],
created_at=datetime.utcnow(),
last_accessed=datetime.utcnow(),
expires_at=expires_at
)
await memory.insert()
logger.debug(f"创建记忆: {memory_id} for Agent {agent_id}")
return memory
@classmethod
async def get_memory(cls, memory_id: str) -> Optional[AgentMemory]:
"""
获取指定记忆
Args:
memory_id: 记忆ID
Returns:
AgentMemory文档或None
"""
return await AgentMemory.find_one(AgentMemory.memory_id == memory_id)
@classmethod
async def get_agent_memories(
cls,
agent_id: str,
memory_type: Optional[str] = None,
limit: int = 50
) -> List[AgentMemory]:
"""
获取Agent的记忆列表
Args:
agent_id: Agent ID
memory_type: 记忆类型(可选)
limit: 返回数量限制
Returns:
记忆列表
"""
query = {"agent_id": agent_id}
if memory_type:
query["memory_type"] = memory_type
return await AgentMemory.find(query).sort(
"-importance", "-last_accessed"
).limit(limit).to_list()
@classmethod
async def search_memories(
cls,
agent_id: str,
query: str,
limit: int = 10,
memory_type: Optional[str] = None,
min_relevance: float = 0.3
) -> List[Dict[str, Any]]:
"""
搜索相关记忆
Args:
agent_id: Agent ID
query: 查询文本
limit: 返回数量
memory_type: 记忆类型(可选)
min_relevance: 最小相关性阈值
Returns:
带相关性分数的记忆列表
"""
# 生成查询向量
query_embedding = await cls._generate_embedding(query)
if not query_embedding:
# 无法生成向量时,使用文本匹配
return await cls._text_search(agent_id, query, limit, memory_type)
# 获取Agent的所有记忆
filter_query = {"agent_id": agent_id}
if memory_type:
filter_query["memory_type"] = memory_type
memories = await AgentMemory.find(filter_query).to_list()
# 计算相似度
results = []
for memory in memories:
if memory.is_expired():
continue
if memory.embedding:
similarity = cls._cosine_similarity(query_embedding, memory.embedding)
relevance = memory.calculate_relevance_score(similarity)
if relevance >= min_relevance:
results.append({
"memory": memory,
"similarity": similarity,
"relevance": relevance
})
# 按相关性排序
results.sort(key=lambda x: x["relevance"], reverse=True)
# 更新访问记录
for item in results[:limit]:
memory = item["memory"]
memory.access()
await memory.save()
return results[:limit]
@classmethod
async def update_memory(
cls,
memory_id: str,
**kwargs
) -> Optional[AgentMemory]:
"""
更新记忆
Args:
memory_id: 记忆ID
**kwargs: 要更新的字段
Returns:
更新后的AgentMemory或None
"""
memory = await cls.get_memory(memory_id)
if not memory:
return None
# 如果更新了内容,重新生成嵌入
if "content" in kwargs:
kwargs["embedding"] = await cls._generate_embedding(kwargs["content"])
kwargs["summary"] = kwargs["content"][:100] + "..." if len(kwargs["content"]) > 100 else kwargs["content"]
for key, value in kwargs.items():
if hasattr(memory, key):
setattr(memory, key, value)
await memory.save()
return memory
@classmethod
async def delete_memory(cls, memory_id: str) -> bool:
"""
删除记忆
Args:
memory_id: 记忆ID
Returns:
是否删除成功
"""
memory = await cls.get_memory(memory_id)
if not memory:
return False
await memory.delete()
return True
@classmethod
async def delete_agent_memories(
cls,
agent_id: str,
memory_type: Optional[str] = None
) -> int:
"""
删除Agent的记忆
Args:
agent_id: Agent ID
memory_type: 记忆类型(可选)
Returns:
删除的数量
"""
query = {"agent_id": agent_id}
if memory_type:
query["memory_type"] = memory_type
result = await AgentMemory.find(query).delete()
return result.deleted_count if result else 0
@classmethod
async def cleanup_expired_memories(cls) -> int:
"""
清理过期的记忆
Returns:
清理的数量
"""
now = datetime.utcnow()
result = await AgentMemory.find(
{"expires_at": {"$lt": now}}
).delete()
count = result.deleted_count if result else 0
if count > 0:
logger.info(f"清理了 {count} 条过期记忆")
return count
@classmethod
async def consolidate_memories(
cls,
agent_id: str,
min_importance: float = 0.7,
max_age_days: int = 30
) -> None:
"""
整合记忆(将重要的短期记忆转为长期记忆)
Args:
agent_id: Agent ID
min_importance: 最小重要性阈值
            max_age_days: 年龄阈值(天),只整合创建时间早于该天数的短期记忆
"""
cutoff_date = datetime.utcnow() - timedelta(days=max_age_days)
# 查找符合条件的短期记忆
memories = await AgentMemory.find({
"agent_id": agent_id,
"memory_type": MemoryType.SHORT_TERM.value,
"importance": {"$gte": min_importance},
"created_at": {"$lt": cutoff_date}
}).to_list()
for memory in memories:
memory.memory_type = MemoryType.LONG_TERM.value
memory.expires_at = None # 长期记忆不过期
await memory.save()
if memories:
logger.info(f"整合了 {len(memories)} 条记忆为长期记忆: Agent {agent_id}")
@classmethod
async def _generate_embedding(cls, text: str) -> List[float]:
"""
生成文本的向量嵌入
Args:
text: 文本内容
Returns:
向量嵌入列表
"""
model = cls._get_embedding_model()
if model is None:
return []
try:
embedding = model.encode(text, convert_to_numpy=True)
return embedding.tolist()
except Exception as e:
logger.warning(f"生成嵌入失败: {e}")
return []
@classmethod
def _cosine_similarity(cls, vec1: List[float], vec2: List[float]) -> float:
"""
计算余弦相似度
Args:
vec1: 向量1
vec2: 向量2
Returns:
相似度 (0-1)
"""
if not vec1 or not vec2:
return 0.0
        try:
            a = np.array(vec1)
            b = np.array(vec2)
            denom = np.linalg.norm(a) * np.linalg.norm(b)
            if denom == 0:
                return 0.0  # 零向量无方向,视为不相似
            similarity = np.dot(a, b) / denom
            return float(max(0.0, similarity))
        except Exception:
            return 0.0
@classmethod
async def _text_search(
cls,
agent_id: str,
query: str,
limit: int,
memory_type: Optional[str]
) -> List[Dict[str, Any]]:
"""
文本搜索(后备方案)
Args:
agent_id: Agent ID
query: 查询文本
limit: 返回数量
memory_type: 记忆类型
Returns:
记忆列表
"""
filter_query = {"agent_id": agent_id}
if memory_type:
filter_query["memory_type"] = memory_type
# 简单的文本匹配
memories = await AgentMemory.find(filter_query).to_list()
results = []
query_lower = query.lower()
for memory in memories:
if memory.is_expired():
continue
content_lower = memory.content.lower()
if query_lower in content_lower:
                # 计算简单的匹配分数(查询占命中文本的比例,防止除零)
                score = len(query_lower) / max(1, len(content_lower))
results.append({
"memory": memory,
"similarity": score,
"relevance": score * memory.importance
})
results.sort(key=lambda x: x["relevance"], reverse=True)
return results[:limit]
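写入与语义检索的最小示意(导入路径为推测,agent_id 等为示例值):

from services.memory_service import MemoryService  # 模块路径为推测

async def demo_memory() -> None:
    await MemoryService.create_memory(
        agent_id="agent-001",
        content="用户偏好简洁、直接的回答风格",
        importance=0.8,
        expires_in_hours=72,  # 72小时后过期,除非先被整合为长期记忆
    )
    hits = await MemoryService.search_memories("agent-001", query="回答风格", limit=5)
    for item in hits:
        print(round(item["relevance"], 3), item["memory"].summary)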

335
backend/services/message_router.py Normal file
View File

@@ -0,0 +1,335 @@
"""
消息路由服务
管理消息的发送和广播
"""
import uuid
import asyncio
from datetime import datetime
from typing import List, Dict, Any, Optional, Callable, Set
from dataclasses import dataclass, field
from loguru import logger
from fastapi import WebSocket
from models.message import Message, MessageType
from models.chatroom import ChatRoom
@dataclass
class WebSocketConnection:
"""WebSocket连接信息"""
websocket: WebSocket
room_id: str
connected_at: datetime = field(default_factory=datetime.utcnow)
class MessageRouter:
"""
消息路由器
管理WebSocket连接和消息广播
"""
# 房间连接映射: room_id -> Set[WebSocket]
_room_connections: Dict[str, Set[WebSocket]] = {}
# 消息回调: 用于外部订阅消息
_message_callbacks: List[Callable] = []
@classmethod
async def connect(cls, room_id: str, websocket: WebSocket) -> None:
"""
建立WebSocket连接
Args:
room_id: 聊天室ID
websocket: WebSocket实例
"""
await websocket.accept()
if room_id not in cls._room_connections:
cls._room_connections[room_id] = set()
cls._room_connections[room_id].add(websocket)
logger.info(f"WebSocket连接建立: {room_id}, 当前连接数: {len(cls._room_connections[room_id])}")
@classmethod
async def disconnect(cls, room_id: str, websocket: WebSocket) -> None:
"""
断开WebSocket连接
Args:
room_id: 聊天室ID
websocket: WebSocket实例
"""
if room_id in cls._room_connections:
cls._room_connections[room_id].discard(websocket)
# 清理空房间
if not cls._room_connections[room_id]:
del cls._room_connections[room_id]
logger.info(f"WebSocket连接断开: {room_id}")
@classmethod
async def broadcast_to_room(
cls,
room_id: str,
message: Dict[str, Any]
) -> None:
"""
向聊天室广播消息
Args:
room_id: 聊天室ID
message: 消息内容
"""
if room_id not in cls._room_connections:
return
# 获取所有连接
connections = cls._room_connections[room_id].copy()
# 并发发送
tasks = []
for websocket in connections:
tasks.append(cls._send_message(room_id, websocket, message))
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
@classmethod
async def _send_message(
cls,
room_id: str,
websocket: WebSocket,
message: Dict[str, Any]
) -> None:
"""
向单个WebSocket发送消息
Args:
room_id: 聊天室ID
websocket: WebSocket实例
message: 消息内容
"""
try:
await websocket.send_json(message)
except Exception as e:
logger.warning(f"WebSocket发送失败: {e}")
# 移除断开的连接
await cls.disconnect(room_id, websocket)
@classmethod
async def save_and_broadcast_message(
cls,
room_id: str,
discussion_id: str,
agent_id: Optional[str],
content: str,
message_type: str = MessageType.TEXT.value,
round_num: int = 0,
attachments: Optional[List[Dict[str, Any]]] = None,
tool_calls: Optional[List[Dict[str, Any]]] = None,
tool_results: Optional[List[Dict[str, Any]]] = None
) -> Message:
"""
保存消息并广播
Args:
room_id: 聊天室ID
discussion_id: 讨论ID
agent_id: 发送Agent ID
content: 消息内容
message_type: 消息类型
round_num: 轮次号
attachments: 附件
tool_calls: 工具调用
tool_results: 工具结果
Returns:
保存的Message文档
"""
# 创建消息
message = Message(
message_id=f"msg-{uuid.uuid4().hex[:12]}",
room_id=room_id,
discussion_id=discussion_id,
agent_id=agent_id,
content=content,
message_type=message_type,
attachments=attachments or [],
round=round_num,
token_count=len(content) // 4, # 粗略估计
tool_calls=tool_calls or [],
tool_results=tool_results or [],
created_at=datetime.utcnow()
)
await message.insert()
# 构建广播消息
broadcast_data = {
"type": "message",
"data": {
"message_id": message.message_id,
"room_id": message.room_id,
"discussion_id": message.discussion_id,
"agent_id": message.agent_id,
"content": message.content,
"message_type": message.message_type,
"round": message.round,
"created_at": message.created_at.isoformat()
}
}
# 广播消息
await cls.broadcast_to_room(room_id, broadcast_data)
# 触发回调
for callback in cls._message_callbacks:
try:
await callback(message)
except Exception as e:
logger.error(f"消息回调执行失败: {e}")
return message
@classmethod
async def broadcast_status(
cls,
room_id: str,
status: str,
data: Optional[Dict[str, Any]] = None
) -> None:
"""
广播状态更新
Args:
room_id: 聊天室ID
status: 状态类型
data: 附加数据
"""
message = {
"type": "status",
"status": status,
"data": data or {},
"timestamp": datetime.utcnow().isoformat()
}
await cls.broadcast_to_room(room_id, message)
@classmethod
async def broadcast_typing(
cls,
room_id: str,
agent_id: str,
is_typing: bool = True
) -> None:
"""
广播Agent输入状态
Args:
room_id: 聊天室ID
agent_id: Agent ID
is_typing: 是否正在输入
"""
message = {
"type": "typing",
"agent_id": agent_id,
"is_typing": is_typing,
"timestamp": datetime.utcnow().isoformat()
}
await cls.broadcast_to_room(room_id, message)
@classmethod
async def broadcast_round_info(
cls,
room_id: str,
round_num: int,
total_rounds: int
) -> None:
"""
广播轮次信息
Args:
room_id: 聊天室ID
round_num: 当前轮次
total_rounds: 最大轮次
"""
message = {
"type": "round",
"round": round_num,
"total_rounds": total_rounds,
"timestamp": datetime.utcnow().isoformat()
}
await cls.broadcast_to_room(room_id, message)
@classmethod
async def broadcast_error(
cls,
room_id: str,
error: str,
agent_id: Optional[str] = None
) -> None:
"""
广播错误信息
Args:
room_id: 聊天室ID
error: 错误信息
agent_id: 相关Agent ID
"""
message = {
"type": "error",
"error": error,
"agent_id": agent_id,
"timestamp": datetime.utcnow().isoformat()
}
await cls.broadcast_to_room(room_id, message)
@classmethod
def register_callback(cls, callback: Callable) -> None:
"""
注册消息回调
Args:
callback: 回调函数接收Message参数
"""
cls._message_callbacks.append(callback)
@classmethod
def unregister_callback(cls, callback: Callable) -> None:
"""
注销消息回调
Args:
callback: 回调函数
"""
if callback in cls._message_callbacks:
cls._message_callbacks.remove(callback)
@classmethod
def get_connection_count(cls, room_id: str) -> int:
"""
获取房间连接数
Args:
room_id: 聊天室ID
Returns:
连接数
"""
return len(cls._room_connections.get(room_id, set()))
@classmethod
def get_all_room_ids(cls) -> List[str]:
"""
获取所有活跃房间ID
Returns:
房间ID列表
"""
return list(cls._room_connections.keys())
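WebSocket 端点只需把连接托管给 MessageRouter;注意 connect 内部已调用 accept,端点不要重复调用。以下为接入示意(路由路径为假设,并非源码的一部分):

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from services.message_router import MessageRouter

app = FastAPI()

@app.websocket("/ws/chatrooms/{room_id}")
async def chatroom_ws(websocket: WebSocket, room_id: str):
    await MessageRouter.connect(room_id, websocket)  # 内部会调用 accept()
    try:
        while True:
            await websocket.receive_text()  # 本端点只负责推送,忽略上行内容
    except WebSocketDisconnect:
        await MessageRouter.disconnect(room_id, websocket)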

13
backend/utils/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
"""
工具函数模块
"""
from .encryption import encrypt_api_key, decrypt_api_key
from .proxy_handler import get_http_client
from .rate_limiter import RateLimiter
__all__ = [
"encrypt_api_key",
"decrypt_api_key",
"get_http_client",
"RateLimiter",
]

97
backend/utils/encryption.py Normal file
View File

@@ -0,0 +1,97 @@
"""
加密工具模块
用于API密钥的加密和解密
"""
import base64
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from loguru import logger
from config import settings
def _get_fernet() -> Fernet:
"""
获取Fernet加密器实例
使用配置的加密密钥派生加密密钥
Returns:
Fernet加密器
"""
# 使用PBKDF2从密钥派生32字节密钥
salt = b"ai_chatroom_salt" # 固定salt实际生产环境应使用随机salt
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=salt,
iterations=100000,
)
key = base64.urlsafe_b64encode(
kdf.derive(settings.ENCRYPTION_KEY.encode())
)
return Fernet(key)
def encrypt_api_key(api_key: str) -> str:
"""
加密API密钥
Args:
api_key: 原始API密钥
Returns:
加密后的密钥字符串
"""
if not api_key:
return ""
try:
fernet = _get_fernet()
encrypted = fernet.encrypt(api_key.encode())
return encrypted.decode()
except Exception as e:
logger.error(f"API密钥加密失败: {e}")
raise ValueError("加密失败")
def decrypt_api_key(encrypted_key: str) -> str:
"""
解密API密钥
Args:
encrypted_key: 加密的密钥字符串
Returns:
解密后的原始API密钥
"""
if not encrypted_key:
return ""
try:
fernet = _get_fernet()
decrypted = fernet.decrypt(encrypted_key.encode())
return decrypted.decode()
except Exception as e:
logger.error(f"API密钥解密失败: {e}")
raise ValueError("解密失败,密钥可能已损坏或被篡改")
def mask_api_key(api_key: str, visible_chars: int = 4) -> str:
"""
掩码API密钥用于安全显示
Args:
api_key: 原始API密钥
visible_chars: 末尾可见字符数
Returns:
掩码后的密钥 (如: ****abc1)
"""
if not api_key:
return ""
if len(api_key) <= visible_chars:
return "*" * len(api_key)
return "*" * (len(api_key) - visible_chars) + api_key[-visible_chars:]

135
backend/utils/proxy_handler.py Normal file
View File

@@ -0,0 +1,135 @@
"""
代理处理模块
处理HTTP代理配置
"""
from typing import Optional, Dict, Any
import httpx
from loguru import logger
from config import settings
def get_proxy_dict(
use_proxy: bool,
proxy_config: Optional[Dict[str, Any]] = None
) -> Optional[Dict[str, str]]:
"""
获取代理配置字典
Args:
use_proxy: 是否使用代理
proxy_config: 代理配置
Returns:
代理配置字典或None
"""
if not use_proxy:
return None
proxies = {}
if proxy_config:
http_proxy = proxy_config.get("http_proxy")
https_proxy = proxy_config.get("https_proxy")
else:
# 使用全局默认代理
http_proxy = settings.DEFAULT_HTTP_PROXY
https_proxy = settings.DEFAULT_HTTPS_PROXY
if http_proxy:
proxies["http://"] = http_proxy
if https_proxy:
proxies["https://"] = https_proxy
return proxies if proxies else None
def get_http_client(
use_proxy: bool = False,
proxy_config: Optional[Dict[str, Any]] = None,
timeout: int = 60,
**kwargs
) -> httpx.AsyncClient:
"""
获取配置好的HTTP异步客户端
Args:
use_proxy: 是否使用代理
proxy_config: 代理配置
timeout: 超时时间(秒)
**kwargs: 其他httpx参数
Returns:
配置好的httpx.AsyncClient实例
"""
proxies = get_proxy_dict(use_proxy, proxy_config)
client_kwargs = {
"timeout": httpx.Timeout(timeout),
"follow_redirects": True,
**kwargs
}
if proxies:
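        # 注意:httpx 0.26 起弃用 proxies 参数(0.28 已移除),升级时需改用 proxy= 或 mounts=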
client_kwargs["proxies"] = proxies
logger.debug(f"HTTP客户端使用代理: {proxies}")
return httpx.AsyncClient(**client_kwargs)
async def test_proxy_connection(
proxy_config: Dict[str, Any],
test_url: str = "https://www.google.com"
) -> Dict[str, Any]:
"""
测试代理连接是否可用
Args:
proxy_config: 代理配置
test_url: 测试URL
Returns:
测试结果字典,包含 success, message, latency_ms
"""
try:
async with get_http_client(
use_proxy=True,
proxy_config=proxy_config,
timeout=10
) as client:
import time
start = time.time()
response = await client.get(test_url)
latency = (time.time() - start) * 1000
if response.status_code == 200:
return {
"success": True,
"message": "代理连接正常",
"latency_ms": round(latency, 2)
}
else:
return {
"success": False,
"message": f"代理返回状态码: {response.status_code}",
"latency_ms": round(latency, 2)
}
except httpx.ProxyError as e:
return {
"success": False,
"message": f"代理连接失败: {str(e)}",
"latency_ms": None
}
except httpx.TimeoutException:
return {
"success": False,
"message": "代理连接超时",
"latency_ms": None
}
except Exception as e:
return {
"success": False,
"message": f"连接错误: {str(e)}",
"latency_ms": None
}
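建议通过 async with 使用,确保连接池被正确关闭。示意(代理地址与URL均为示例值):

from utils.proxy_handler import get_http_client

async def demo_proxy() -> None:
    cfg = {"http_proxy": "http://127.0.0.1:7890", "https_proxy": "http://127.0.0.1:7890"}
    async with get_http_client(use_proxy=True, proxy_config=cfg, timeout=30) as client:
        resp = await client.get("https://api.example.com/v1/models")  # 示例URL
        print(resp.status_code)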

233
backend/utils/rate_limiter.py Normal file
View File

@@ -0,0 +1,233 @@
"""
速率限制器模块
使用令牌桶算法控制请求频率
"""
import asyncio
import time
from typing import Any, Dict, Optional
from dataclasses import dataclass, field
from loguru import logger
@dataclass
class TokenBucket:
"""令牌桶"""
capacity: int # 桶容量
tokens: float = field(init=False) # 当前令牌数
refill_rate: float # 每秒填充速率
last_refill: float = field(default_factory=time.time)
def __post_init__(self):
self.tokens = float(self.capacity)
def _refill(self) -> None:
"""填充令牌"""
now = time.time()
elapsed = now - self.last_refill
self.tokens = min(
self.capacity,
self.tokens + elapsed * self.refill_rate
)
self.last_refill = now
def consume(self, tokens: int = 1) -> bool:
"""
尝试消费令牌
Args:
tokens: 要消费的令牌数
Returns:
是否消费成功
"""
self._refill()
if self.tokens >= tokens:
self.tokens -= tokens
return True
return False
def wait_time(self, tokens: int = 1) -> float:
"""
计算需要等待的时间
Args:
tokens: 需要的令牌数
Returns:
需要等待的秒数
"""
self._refill()
if self.tokens >= tokens:
return 0.0
needed = tokens - self.tokens
return needed / self.refill_rate
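以 60 请求/分钟为例:容量 60、每秒回填 1 个令牌,允许一次性突发 60 个请求,之后退化为约每秒 1 个(示意,并非源码的一部分):

bucket = TokenBucket(capacity=60, refill_rate=1.0)  # tokens字段由__post_init__填满
for _ in range(60):
    assert bucket.consume()         # 突发:60个请求立即放行
assert not bucket.consume()         # 桶已空,第61个被拒绝
print(f"{bucket.wait_time():.2f}")  # 约1.00秒后才能再取1个令牌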
class RateLimiter:
"""
速率限制器
管理多个提供商的速率限制
"""
def __init__(self):
self._buckets: Dict[str, TokenBucket] = {}
self._locks: Dict[str, asyncio.Lock] = {}
def register(
self,
provider_id: str,
requests_per_minute: int = 60,
tokens_per_minute: int = 100000
) -> None:
"""
注册提供商的速率限制
Args:
provider_id: 提供商ID
requests_per_minute: 每分钟请求数
tokens_per_minute: 每分钟token数
"""
# 请求限制桶
self._buckets[f"{provider_id}:requests"] = TokenBucket(
capacity=requests_per_minute,
refill_rate=requests_per_minute / 60.0
)
# Token限制桶
self._buckets[f"{provider_id}:tokens"] = TokenBucket(
capacity=tokens_per_minute,
refill_rate=tokens_per_minute / 60.0
)
# 创建锁
self._locks[provider_id] = asyncio.Lock()
logger.debug(
f"注册速率限制: {provider_id} - "
f"{requests_per_minute}请求/分钟, "
f"{tokens_per_minute}tokens/分钟"
)
def unregister(self, provider_id: str) -> None:
"""
取消注册提供商的速率限制
Args:
provider_id: 提供商ID
"""
self._buckets.pop(f"{provider_id}:requests", None)
self._buckets.pop(f"{provider_id}:tokens", None)
self._locks.pop(provider_id, None)
async def acquire(
self,
provider_id: str,
estimated_tokens: int = 1
) -> bool:
"""
获取请求许可(非阻塞)
Args:
provider_id: 提供商ID
estimated_tokens: 预估token数
Returns:
是否获取成功
"""
request_bucket = self._buckets.get(f"{provider_id}:requests")
token_bucket = self._buckets.get(f"{provider_id}:tokens")
if not request_bucket or not token_bucket:
# 未注册,默认允许
return True
        lock = self._locks.get(provider_id)
        if not lock:
            # 未注册锁时与acquire_wait保持一致,默认允许
            return True
        async with lock:
            # 先确认两个桶都有余量,再一并扣减,避免短路求值只扣请求桶不扣token桶
            if request_bucket.wait_time(1) == 0 and token_bucket.wait_time(estimated_tokens) == 0:
                request_bucket.consume(1)
                token_bucket.consume(estimated_tokens)
                return True
            return False
async def acquire_wait(
self,
provider_id: str,
estimated_tokens: int = 1,
max_wait: float = 60.0
) -> bool:
"""
获取请求许可(阻塞等待)
Args:
provider_id: 提供商ID
estimated_tokens: 预估token数
max_wait: 最大等待时间(秒)
Returns:
是否获取成功
"""
request_bucket = self._buckets.get(f"{provider_id}:requests")
token_bucket = self._buckets.get(f"{provider_id}:tokens")
if not request_bucket or not token_bucket:
return True
lock = self._locks.get(provider_id)
if not lock:
return True
start_time = time.time()
while True:
async with lock:
# 计算需要等待的时间
request_wait = request_bucket.wait_time(1)
token_wait = token_bucket.wait_time(estimated_tokens)
wait_time = max(request_wait, token_wait)
if wait_time == 0:
request_bucket.consume(1)
token_bucket.consume(estimated_tokens)
return True
# 检查是否超时
elapsed = time.time() - start_time
if elapsed + wait_time > max_wait:
logger.warning(
f"速率限制等待超时: {provider_id}, "
f"需要等待{wait_time:.2f}"
)
return False
# 在锁外等待
await asyncio.sleep(min(wait_time, 1.0))
    def get_status(self, provider_id: str) -> Optional[Dict[str, Any]]:
"""
获取提供商的速率限制状态
Args:
provider_id: 提供商ID
Returns:
状态字典
"""
request_bucket = self._buckets.get(f"{provider_id}:requests")
token_bucket = self._buckets.get(f"{provider_id}:tokens")
if not request_bucket or not token_bucket:
return None
request_bucket._refill()
token_bucket._refill()
return {
"requests_remaining": int(request_bucket.tokens),
"requests_capacity": request_bucket.capacity,
"tokens_remaining": int(token_bucket.tokens),
"tokens_capacity": token_bucket.capacity
}
# 全局速率限制器实例
rate_limiter = RateLimiter()
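调用方先为提供商注册限额,每次请求前再等待配额。示意(provider_id 与限额均为示例值):

from utils.rate_limiter import rate_limiter

async def demo_rate_limit() -> None:
    rate_limiter.register("openrouter-1", requests_per_minute=60, tokens_per_minute=100000)
    if await rate_limiter.acquire_wait("openrouter-1", estimated_tokens=800, max_wait=30.0):
        ...  # 在此调用上游AI接口
    else:
        print("限流等待超时,稍后重试")
    print(rate_limiter.get_status("openrouter-1"))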