feat: AI聊天室多Agent协作讨论平台
- 实现Agent管理,支持AI辅助生成系统提示词 - 支持多个AI提供商(OpenRouter、智谱、MiniMax等) - 实现聊天室和讨论引擎 - WebSocket实时消息推送 - 前端使用React + Ant Design - 后端使用FastAPI + MongoDB Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
26
.claude/settings.local.json
Normal file
26
.claude/settings.local.json
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
{
|
||||||
|
"permissions": {
|
||||||
|
"allow": [
|
||||||
|
"Bash(powershell:*)",
|
||||||
|
"Bash(where:*)",
|
||||||
|
"Bash(docker:*)",
|
||||||
|
"Bash(docker-compose:*)",
|
||||||
|
"Bash(npm --version)",
|
||||||
|
"Bash(mongosh:*)",
|
||||||
|
"Bash(mongo:*)",
|
||||||
|
"mcp__web-search-prime__webSearchPrime",
|
||||||
|
"Bash(python:*)",
|
||||||
|
"Bash(set HTTP_PROXY=)",
|
||||||
|
"Bash(set HTTPS_PROXY=)",
|
||||||
|
"Bash(export:*)",
|
||||||
|
"Bash(export HTTPS_PROXY=\"\")",
|
||||||
|
"Bash(curl:*)",
|
||||||
|
"Bash(taskkill:*)",
|
||||||
|
"Bash(cmd.exe /c \"tasklist | findstr /i \"\"28876 53736 52620 24000 52468 18560 37980 51192 python\"\"\")",
|
||||||
|
"Bash(cmd.exe:*)",
|
||||||
|
"Bash(git init:*)",
|
||||||
|
"Bash(git config:*)",
|
||||||
|
"Bash(git add:*)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
15
.env.example
Normal file
15
.env.example
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# AI聊天室环境变量配置
|
||||||
|
|
||||||
|
# MongoDB配置
|
||||||
|
MONGO_PASSWORD=chatroom123
|
||||||
|
|
||||||
|
# 安全配置
|
||||||
|
SECRET_KEY=your-secret-key-change-in-production
|
||||||
|
ENCRYPTION_KEY=your-32-byte-encryption-key-here
|
||||||
|
|
||||||
|
# 开发模式
|
||||||
|
DEBUG=false
|
||||||
|
|
||||||
|
# 可选:默认代理配置
|
||||||
|
DEFAULT_HTTP_PROXY=
|
||||||
|
DEFAULT_HTTPS_PROXY=
|
||||||
47
.gitignore
vendored
Normal file
47
.gitignore
vendored
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
|
||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
*.pyc
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env/
|
||||||
|
|
||||||
|
# Node
|
||||||
|
node_modules/
|
||||||
|
npm-debug.log
|
||||||
|
yarn-error.log
|
||||||
|
dist/
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
*.log
|
||||||
|
|
||||||
99
CLAUDE.md
Normal file
99
CLAUDE.md
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
# CLAUDE.md
|
||||||
|
|
||||||
|
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||||
|
|
||||||
|
## Project Overview
|
||||||
|
|
||||||
|
AI Chat Room (AI聊天室) is a multi-agent collaborative discussion platform. Users configure AI providers, create agents with different roles, and let them discuss in chat rooms to reach consensus.
|
||||||
|
|
||||||
|
**Stack:** FastAPI (Python 3.11+) backend + React 18 (TypeScript) frontend + MongoDB database. Real-time communication via WebSockets.
|
||||||
|
|
||||||
|
## Development Commands
|
||||||
|
|
||||||
|
### Docker (Recommended)
|
||||||
|
```bash
|
||||||
|
# Start all services
|
||||||
|
docker-compose up -d
|
||||||
|
|
||||||
|
# Rebuild after changes
|
||||||
|
docker-compose up -d --build
|
||||||
|
|
||||||
|
# View logs
|
||||||
|
docker-compose logs -f backend
|
||||||
|
docker-compose logs -f frontend
|
||||||
|
```
|
||||||
|
|
||||||
|
### Backend (Local)
|
||||||
|
```bash
|
||||||
|
cd backend
|
||||||
|
python -m venv venv
|
||||||
|
venv\Scripts\activate # Windows: activate, Linux/Mac: source venv/bin/activate
|
||||||
|
pip install -r requirements.txt
|
||||||
|
python main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Backend runs on http://localhost:8000 - API docs at http://localhost:8000/docs
|
||||||
|
|
||||||
|
### Frontend (Local)
|
||||||
|
```bash
|
||||||
|
cd frontend
|
||||||
|
npm install
|
||||||
|
npm run dev # Development server (Vite)
|
||||||
|
npm run build # Production build (tsc && vite build)
|
||||||
|
```
|
||||||
|
|
||||||
|
Frontend runs on http://localhost:3000
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
### Backend Structure
|
||||||
|
- **[adapters/](backend/adapters/)** - AI provider integrations using Adapter pattern
|
||||||
|
- `base_adapter.py` - Abstract base class with `ChatMessage`, `AdapterResponse`, `BaseAdapter`
|
||||||
|
- Each adapter implements `chat()`, `chat_stream()`, `test_connection()`
|
||||||
|
- Supported: OpenRouter, Zhipu (智谱), MiniMax, Kimi, DeepSeek, Gemini, Ollama, LLM Studio
|
||||||
|
- **[models/](backend/models/)** - Beanie ODM documents for MongoDB
|
||||||
|
- **[services/](backend/services/)** - Business logic layer
|
||||||
|
- `discussion_engine.py` - Core multi-agent discussion orchestration
|
||||||
|
- `consensus_manager.py` - Moderator agent evaluates if consensus reached
|
||||||
|
- `message_router.py` - WebSocket message routing
|
||||||
|
- **[routers/](backend/routers/)** - FastAPI route handlers (providers, agents, chatrooms, discussions)
|
||||||
|
- **[utils/](backend/utils/)** - encryption.py (API keys), proxy_handler.py, rate_limiter.py
|
||||||
|
|
||||||
|
### Frontend Structure
|
||||||
|
- **[src/stores/](frontend/src/stores/)** - Zustand state management
|
||||||
|
- **[src/services/](frontend/src/services/)** - API client and WebSocket client
|
||||||
|
- **[src/pages/](frontend/src/pages/)** - Dashboard, ProviderConfig, AgentManagement, ChatRoom, DiscussionHistory
|
||||||
|
- **[src/components/](frontend/src/components/)** - Reusable UI components using Ant Design
|
||||||
|
|
||||||
|
### Key Data Flow
|
||||||
|
1. User creates agents (role + system prompt) and assigns AI providers
|
||||||
|
2. Chat room created with selected agents + optional moderator
|
||||||
|
3. Discussion started: `discussion_engine.py` orchestrates turn-based agent interactions
|
||||||
|
4. Each round: agents receive context and decide whether to speak (role relevance)
|
||||||
|
5. Moderator agent periodically checks for consensus via `consensus_manager.py`
|
||||||
|
6. WebSocket streams messages in real-time to frontend
|
||||||
|
|
||||||
|
### Adding New AI Providers
|
||||||
|
1. Create new adapter in `backend/adapters/` inheriting from `BaseAdapter`
|
||||||
|
2. Implement async methods: `chat()`, `chat_stream()`, `test_connection()`
|
||||||
|
3. Register in `backend/adapters/__init__.py`
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Environment variables in `.env`:
|
||||||
|
- `MONGODB_URL` - MongoDB connection string
|
||||||
|
- `MONGODB_DB` - Database name (default: ai_chatroom)
|
||||||
|
- `SECRET_KEY` - JWT signing key
|
||||||
|
- `ENCRYPTION_KEY` - 32-byte key for API key encryption
|
||||||
|
- `DEFAULT_HTTP_PROXY` / `DEFAULT_HTTPS_PROXY` - Proxy for overseas APIs
|
||||||
|
|
||||||
|
Backend config in [backend/config.py](backend/config.py) - Pydantic Settings with defaults.
|
||||||
|
|
||||||
|
## Important Notes
|
||||||
|
|
||||||
|
- All async/await - Python async functions throughout backend
|
||||||
|
- API keys encrypted at rest using `cryptography` Fernet
|
||||||
|
- WebSocket heartbeat every 30s (`WS_HEARTBEAT_INTERVAL`)
|
||||||
|
- CORS origins configured in settings for local development
|
||||||
|
- MongoDB indexes created automatically by Beanie on startup
|
||||||
|
- Chinese language UI (README and comments in Chinese)
|
||||||
207
README.md
Normal file
207
README.md
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
# AI聊天室 - 多Agent协作讨论平台
|
||||||
|
|
||||||
|
一个基于多AI Agent协作的聊天室系统,支持配置多种AI接口,创建具有不同角色的Agent,并让它们在聊天室中针对特定目标进行自由讨论,最终达成共识。
|
||||||
|
|
||||||
|
## 功能特性
|
||||||
|
|
||||||
|
### 1. AI接口管理
|
||||||
|
- 支持多种AI提供商:MiniMax、智谱、OpenRouter、Kimi、DeepSeek、Gemini
|
||||||
|
- 支持本地模型:Ollama、LLM Studio
|
||||||
|
- 可配置代理(支持海外API访问)
|
||||||
|
- API密钥加密存储
|
||||||
|
- 速率限制和连接测试
|
||||||
|
|
||||||
|
### 2. Agent管理
|
||||||
|
- 创建自定义Agent角色
|
||||||
|
- 配置系统提示词和行为参数
|
||||||
|
- 支持记忆功能、MCP工具、多模态输入
|
||||||
|
- 预设模板快速创建(产品经理、开发工程师、设计师等)
|
||||||
|
- Agent测试对话功能
|
||||||
|
|
||||||
|
### 3. 聊天室
|
||||||
|
- 创建讨论聊天室,添加多个Agent参与
|
||||||
|
- 设置讨论目标,启动自由讨论
|
||||||
|
- Agent根据角色自主判断是否发言
|
||||||
|
- 实时WebSocket消息推送
|
||||||
|
- 支持暂停、恢复、停止讨论
|
||||||
|
|
||||||
|
### 4. 共识判断
|
||||||
|
- 主持人Agent判断讨论是否达成共识
|
||||||
|
- 自动生成讨论摘要和行动项
|
||||||
|
- 记录未解决的问题
|
||||||
|
- 讨论历史回放
|
||||||
|
|
||||||
|
## 技术架构
|
||||||
|
|
||||||
|
```
|
||||||
|
├── backend/ # FastAPI后端
|
||||||
|
│ ├── adapters/ # AI接口适配器
|
||||||
|
│ ├── models/ # 数据模型
|
||||||
|
│ ├── services/ # 业务逻辑
|
||||||
|
│ ├── routers/ # API路由
|
||||||
|
│ └── utils/ # 工具函数
|
||||||
|
├── frontend/ # React前端
|
||||||
|
│ ├── src/
|
||||||
|
│ │ ├── components/ # UI组件
|
||||||
|
│ │ ├── pages/ # 页面
|
||||||
|
│ │ ├── stores/ # 状态管理
|
||||||
|
│ │ └── services/ # API服务
|
||||||
|
└── docker-compose.yml # Docker编排
|
||||||
|
```
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
### 使用Docker(推荐)
|
||||||
|
|
||||||
|
1. 克隆项目并进入目录:
|
||||||
|
```bash
|
||||||
|
cd AIChatRoom
|
||||||
|
```
|
||||||
|
|
||||||
|
2. 复制环境变量配置:
|
||||||
|
```bash
|
||||||
|
cp .env.example .env
|
||||||
|
# 编辑 .env 文件,修改密钥等配置
|
||||||
|
```
|
||||||
|
|
||||||
|
3. 启动服务:
|
||||||
|
```bash
|
||||||
|
docker-compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
4. 访问应用:
|
||||||
|
- 前端:http://localhost:3000
|
||||||
|
- 后端API:http://localhost:8000
|
||||||
|
- API文档:http://localhost:8000/docs
|
||||||
|
|
||||||
|
### 本地开发
|
||||||
|
|
||||||
|
#### 后端
|
||||||
|
|
||||||
|
1. 创建Python虚拟环境:
|
||||||
|
```bash
|
||||||
|
cd backend
|
||||||
|
python -m venv venv
|
||||||
|
source venv/bin/activate # Windows: venv\Scripts\activate
|
||||||
|
```
|
||||||
|
|
||||||
|
2. 安装依赖:
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
3. 配置环境变量:
|
||||||
|
```bash
|
||||||
|
cp .env.example .env
|
||||||
|
# 编辑 .env 文件
|
||||||
|
```
|
||||||
|
|
||||||
|
4. 启动MongoDB:
|
||||||
|
```bash
|
||||||
|
# 确保MongoDB已安装并运行
|
||||||
|
mongod
|
||||||
|
```
|
||||||
|
|
||||||
|
5. 启动后端:
|
||||||
|
```bash
|
||||||
|
python main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 前端
|
||||||
|
|
||||||
|
1. 安装依赖:
|
||||||
|
```bash
|
||||||
|
cd frontend
|
||||||
|
npm install
|
||||||
|
```
|
||||||
|
|
||||||
|
2. 启动开发服务器:
|
||||||
|
```bash
|
||||||
|
npm run dev
|
||||||
|
```
|
||||||
|
|
||||||
|
## 使用指南
|
||||||
|
|
||||||
|
### 1. 配置AI接口
|
||||||
|
|
||||||
|
1. 进入"AI接口配置"页面
|
||||||
|
2. 点击"添加接口"
|
||||||
|
3. 选择提供商类型,填写API密钥和模型名称
|
||||||
|
4. 如需代理,开启"使用代理"并填写代理地址
|
||||||
|
5. 点击测试验证连接
|
||||||
|
|
||||||
|
### 2. 创建Agent
|
||||||
|
|
||||||
|
1. 进入"Agent管理"页面
|
||||||
|
2. 使用预设模板或点击"创建Agent"
|
||||||
|
3. 填写名称、角色、系统提示词
|
||||||
|
4. 选择关联的AI接口
|
||||||
|
5. 配置行为参数(温度、发言倾向等)
|
||||||
|
|
||||||
|
### 3. 开始讨论
|
||||||
|
|
||||||
|
1. 进入"聊天室"页面
|
||||||
|
2. 创建新聊天室,选择参与的Agent
|
||||||
|
3. 可指定主持人Agent用于共识判断
|
||||||
|
4. 点击"开始讨论",输入讨论目标
|
||||||
|
5. 观察Agent自由讨论,等待达成共识
|
||||||
|
|
||||||
|
## API文档
|
||||||
|
|
||||||
|
启动后端后访问 http://localhost:8000/docs 查看完整的Swagger API文档。
|
||||||
|
|
||||||
|
### 主要端点
|
||||||
|
|
||||||
|
- `POST /api/providers` - 创建AI接口配置
|
||||||
|
- `POST /api/agents` - 创建Agent
|
||||||
|
- `POST /api/chatrooms` - 创建聊天室
|
||||||
|
- `POST /api/chatrooms/{id}/start` - 启动讨论
|
||||||
|
- `WS /api/chatrooms/ws/{id}` - 聊天室WebSocket
|
||||||
|
|
||||||
|
## 配置说明
|
||||||
|
|
||||||
|
### 环境变量
|
||||||
|
|
||||||
|
| 变量 | 说明 | 默认值 |
|
||||||
|
|------|------|--------|
|
||||||
|
| MONGODB_URL | MongoDB连接地址 | mongodb://localhost:27017 |
|
||||||
|
| MONGODB_DB | 数据库名称 | ai_chatroom |
|
||||||
|
| SECRET_KEY | 应用密钥 | - |
|
||||||
|
| ENCRYPTION_KEY | 加密密钥(32字节) | - |
|
||||||
|
| DEFAULT_HTTP_PROXY | 默认HTTP代理 | - |
|
||||||
|
| DEFAULT_HTTPS_PROXY | 默认HTTPS代理 | - |
|
||||||
|
|
||||||
|
### 代理配置
|
||||||
|
|
||||||
|
对于需要代理访问的AI接口(如Gemini、OpenRouter),在接口配置中开启代理:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"use_proxy": true,
|
||||||
|
"proxy_config": {
|
||||||
|
"http_proxy": "http://127.0.0.1:7890",
|
||||||
|
"https_proxy": "http://127.0.0.1:7890"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 扩展开发
|
||||||
|
|
||||||
|
### 添加新的AI提供商
|
||||||
|
|
||||||
|
1. 在 `backend/adapters/` 创建新的适配器文件
|
||||||
|
2. 继承 `BaseAdapter` 类
|
||||||
|
3. 实现 `chat`、`chat_stream`、`test_connection` 方法
|
||||||
|
4. 在 `adapters/__init__.py` 注册新适配器
|
||||||
|
|
||||||
|
### 自定义共识判断逻辑
|
||||||
|
|
||||||
|
修改 `backend/services/consensus_manager.py` 中的 `CONSENSUS_PROMPT` 模板,调整共识判断的标准。
|
||||||
|
|
||||||
|
## 许可证
|
||||||
|
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
## 贡献
|
||||||
|
|
||||||
|
欢迎提交Issue和Pull Request!
|
||||||
16
backend/.env.example
Normal file
16
backend/.env.example
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# MongoDB配置
|
||||||
|
MONGODB_URL=mongodb://localhost:27017
|
||||||
|
MONGODB_DB=ai_chatroom
|
||||||
|
|
||||||
|
# 服务配置
|
||||||
|
HOST=0.0.0.0
|
||||||
|
PORT=8000
|
||||||
|
DEBUG=true
|
||||||
|
|
||||||
|
# 安全配置(生产环境请修改)
|
||||||
|
SECRET_KEY=your-secret-key-change-in-production
|
||||||
|
ENCRYPTION_KEY=your-encryption-key-32-bytes-long
|
||||||
|
|
||||||
|
# 代理配置(可选)
|
||||||
|
# DEFAULT_HTTP_PROXY=http://127.0.0.1:7890
|
||||||
|
# DEFAULT_HTTPS_PROXY=http://127.0.0.1:7890
|
||||||
25
backend/Dockerfile
Normal file
25
backend/Dockerfile
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
# AI聊天室后端 Dockerfile
|
||||||
|
FROM python:3.11-slim
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# 安装系统依赖
|
||||||
|
RUN apt-get update && apt-get install -y \
|
||||||
|
build-essential \
|
||||||
|
curl \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# 复制依赖文件
|
||||||
|
COPY requirements.txt .
|
||||||
|
|
||||||
|
# 安装Python依赖
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
# 复制应用代码
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
# 暴露端口
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
# 启动命令
|
||||||
|
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
58
backend/adapters/__init__.py
Normal file
58
backend/adapters/__init__.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
"""
|
||||||
|
AI接口适配器模块
|
||||||
|
提供统一的AI调用接口
|
||||||
|
"""
|
||||||
|
from .base_adapter import BaseAdapter, AdapterResponse, ChatMessage
|
||||||
|
from .minimax_adapter import MiniMaxAdapter
|
||||||
|
from .zhipu_adapter import ZhipuAdapter
|
||||||
|
from .openrouter_adapter import OpenRouterAdapter
|
||||||
|
from .kimi_adapter import KimiAdapter
|
||||||
|
from .deepseek_adapter import DeepSeekAdapter
|
||||||
|
from .gemini_adapter import GeminiAdapter
|
||||||
|
from .ollama_adapter import OllamaAdapter
|
||||||
|
from .llmstudio_adapter import LLMStudioAdapter
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseAdapter",
|
||||||
|
"AdapterResponse",
|
||||||
|
"ChatMessage",
|
||||||
|
"MiniMaxAdapter",
|
||||||
|
"ZhipuAdapter",
|
||||||
|
"OpenRouterAdapter",
|
||||||
|
"KimiAdapter",
|
||||||
|
"DeepSeekAdapter",
|
||||||
|
"GeminiAdapter",
|
||||||
|
"OllamaAdapter",
|
||||||
|
"LLMStudioAdapter",
|
||||||
|
]
|
||||||
|
|
||||||
|
# 适配器注册表
|
||||||
|
ADAPTER_REGISTRY = {
|
||||||
|
"minimax": MiniMaxAdapter,
|
||||||
|
"zhipu": ZhipuAdapter,
|
||||||
|
"openrouter": OpenRouterAdapter,
|
||||||
|
"kimi": KimiAdapter,
|
||||||
|
"deepseek": DeepSeekAdapter,
|
||||||
|
"gemini": GeminiAdapter,
|
||||||
|
"ollama": OllamaAdapter,
|
||||||
|
"llmstudio": LLMStudioAdapter,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_adapter(provider_type: str) -> type:
|
||||||
|
"""
|
||||||
|
根据提供商类型获取对应的适配器类
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_type: 提供商类型标识
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
适配器类
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 未知的提供商类型
|
||||||
|
"""
|
||||||
|
adapter_class = ADAPTER_REGISTRY.get(provider_type.lower())
|
||||||
|
if not adapter_class:
|
||||||
|
raise ValueError(f"未知的AI提供商类型: {provider_type}")
|
||||||
|
return adapter_class
|
||||||
166
backend/adapters/base_adapter.py
Normal file
166
backend/adapters/base_adapter.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
"""
|
||||||
|
AI适配器基类
|
||||||
|
定义统一的AI调用接口
|
||||||
|
"""
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatMessage:
|
||||||
|
"""聊天消息"""
|
||||||
|
role: str # system, user, assistant
|
||||||
|
content: str # 消息内容
|
||||||
|
name: Optional[str] = None # 发送者名称(可选)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""转换为字典"""
|
||||||
|
d = {"role": self.role, "content": self.content}
|
||||||
|
if self.name:
|
||||||
|
d["name"] = self.name
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AdapterResponse:
|
||||||
|
"""适配器响应"""
|
||||||
|
success: bool # 是否成功
|
||||||
|
content: str = "" # 响应内容
|
||||||
|
error: Optional[str] = None # 错误信息
|
||||||
|
|
||||||
|
# 统计信息
|
||||||
|
prompt_tokens: int = 0 # 输入token数
|
||||||
|
completion_tokens: int = 0 # 输出token数
|
||||||
|
total_tokens: int = 0 # 总token数
|
||||||
|
|
||||||
|
# 元数据
|
||||||
|
model: str = "" # 使用的模型
|
||||||
|
finish_reason: str = "" # 结束原因
|
||||||
|
latency_ms: float = 0.0 # 延迟(毫秒)
|
||||||
|
|
||||||
|
# 工具调用结果
|
||||||
|
tool_calls: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.total_tokens == 0:
|
||||||
|
self.total_tokens = self.prompt_tokens + self.completion_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAdapter(ABC):
|
||||||
|
"""
|
||||||
|
AI适配器基类
|
||||||
|
所有AI提供商适配器必须继承此类
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str,
|
||||||
|
base_url: str,
|
||||||
|
model: str,
|
||||||
|
use_proxy: bool = False,
|
||||||
|
proxy_config: Optional[Dict[str, Any]] = None,
|
||||||
|
timeout: int = 60,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化适配器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: API密钥
|
||||||
|
base_url: API基础URL
|
||||||
|
model: 模型名称
|
||||||
|
use_proxy: 是否使用代理
|
||||||
|
proxy_config: 代理配置
|
||||||
|
timeout: 超时时间(秒)
|
||||||
|
**kwargs: 额外参数
|
||||||
|
"""
|
||||||
|
self.api_key = api_key
|
||||||
|
self.base_url = base_url
|
||||||
|
self.model = model
|
||||||
|
self.use_proxy = use_proxy
|
||||||
|
self.proxy_config = proxy_config or {}
|
||||||
|
self.timeout = timeout
|
||||||
|
self.extra_params = kwargs
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 2000,
|
||||||
|
**kwargs
|
||||||
|
) -> AdapterResponse:
|
||||||
|
"""
|
||||||
|
发送聊天请求
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: 消息列表
|
||||||
|
temperature: 温度参数
|
||||||
|
max_tokens: 最大token数
|
||||||
|
**kwargs: 额外参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
适配器响应
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def chat_stream(
|
||||||
|
self,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 2000,
|
||||||
|
**kwargs
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""
|
||||||
|
发送流式聊天请求
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: 消息列表
|
||||||
|
temperature: 温度参数
|
||||||
|
max_tokens: 最大token数
|
||||||
|
**kwargs: 额外参数
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
响应内容片段
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def test_connection(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
测试API连接
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
测试结果字典,包含 success, message, latency_ms
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _build_messages(
|
||||||
|
self,
|
||||||
|
messages: List[ChatMessage]
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
构建消息列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: ChatMessage列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
字典格式的消息列表
|
||||||
|
"""
|
||||||
|
return [msg.to_dict() for msg in messages]
|
||||||
|
|
||||||
|
def _calculate_latency(self, start_time: datetime) -> float:
|
||||||
|
"""
|
||||||
|
计算延迟
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_time: 开始时间
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
延迟毫秒数
|
||||||
|
"""
|
||||||
|
return (datetime.utcnow() - start_time).total_seconds() * 1000
|
||||||
197
backend/adapters/deepseek_adapter.py
Normal file
197
backend/adapters/deepseek_adapter.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
"""
|
||||||
|
DeepSeek适配器
|
||||||
|
支持DeepSeek大模型API
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
|
||||||
|
from utils.proxy_handler import get_http_client
|
||||||
|
|
||||||
|
|
||||||
|
class DeepSeekAdapter(BaseAdapter):
|
||||||
|
"""
|
||||||
|
DeepSeek API适配器
|
||||||
|
兼容OpenAI API格式
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_BASE_URL = "https://api.deepseek.com/v1"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str,
|
||||||
|
base_url: str = "",
|
||||||
|
model: str = "deepseek-chat",
|
||||||
|
use_proxy: bool = False,
|
||||||
|
proxy_config: Optional[Dict[str, Any]] = None,
|
||||||
|
timeout: int = 60,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url or self.DEFAULT_BASE_URL,
|
||||||
|
model=model,
|
||||||
|
use_proxy=use_proxy,
|
||||||
|
proxy_config=proxy_config,
|
||||||
|
timeout=timeout,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 2000,
|
||||||
|
**kwargs
|
||||||
|
) -> AdapterResponse:
|
||||||
|
"""发送聊天请求"""
|
||||||
|
start_time = datetime.utcnow()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with get_http_client(
|
||||||
|
use_proxy=self.use_proxy,
|
||||||
|
proxy_config=self.proxy_config,
|
||||||
|
timeout=self.timeout
|
||||||
|
) as client:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": self._build_messages(messages),
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"{self.base_url}/chat/completions",
|
||||||
|
headers=headers,
|
||||||
|
json=payload
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
error_text = response.text
|
||||||
|
logger.error(f"DeepSeek API错误: {response.status_code} - {error_text}")
|
||||||
|
return AdapterResponse(
|
||||||
|
success=False,
|
||||||
|
error=f"API错误: {response.status_code} - {error_text}",
|
||||||
|
latency_ms=self._calculate_latency(start_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
choice = data.get("choices", [{}])[0]
|
||||||
|
message = choice.get("message", {})
|
||||||
|
usage = data.get("usage", {})
|
||||||
|
|
||||||
|
return AdapterResponse(
|
||||||
|
success=True,
|
||||||
|
content=message.get("content", ""),
|
||||||
|
model=data.get("model", self.model),
|
||||||
|
finish_reason=choice.get("finish_reason", ""),
|
||||||
|
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||||
|
completion_tokens=usage.get("completion_tokens", 0),
|
||||||
|
total_tokens=usage.get("total_tokens", 0),
|
||||||
|
latency_ms=self._calculate_latency(start_time),
|
||||||
|
tool_calls=message.get("tool_calls", [])
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"DeepSeek请求异常: {e}")
|
||||||
|
return AdapterResponse(
|
||||||
|
success=False,
|
||||||
|
error=str(e),
|
||||||
|
latency_ms=self._calculate_latency(start_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat_stream(
|
||||||
|
self,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 2000,
|
||||||
|
**kwargs
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""发送流式聊天请求"""
|
||||||
|
try:
|
||||||
|
async with get_http_client(
|
||||||
|
use_proxy=self.use_proxy,
|
||||||
|
proxy_config=self.proxy_config,
|
||||||
|
timeout=self.timeout
|
||||||
|
) as client:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": self._build_messages(messages),
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"stream": True,
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
async with client.stream(
|
||||||
|
"POST",
|
||||||
|
f"{self.base_url}/chat/completions",
|
||||||
|
headers=headers,
|
||||||
|
json=payload
|
||||||
|
) as response:
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line.startswith("data: "):
|
||||||
|
data_str = line[6:]
|
||||||
|
if data_str.strip() == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
data = json.loads(data_str)
|
||||||
|
delta = data.get("choices", [{}])[0].get("delta", {})
|
||||||
|
content = delta.get("content", "")
|
||||||
|
if content:
|
||||||
|
yield content
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"DeepSeek流式请求异常: {e}")
|
||||||
|
yield f"[错误: {str(e)}]"
|
||||||
|
|
||||||
|
async def test_connection(self) -> Dict[str, Any]:
|
||||||
|
"""测试API连接"""
|
||||||
|
start_time = datetime.utcnow()
|
||||||
|
|
||||||
|
try:
|
||||||
|
test_messages = [
|
||||||
|
ChatMessage(role="user", content="你好,请回复'OK'")
|
||||||
|
]
|
||||||
|
|
||||||
|
response = await self.chat(
|
||||||
|
messages=test_messages,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=10
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.success:
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": "连接成功",
|
||||||
|
"model": response.model,
|
||||||
|
"latency_ms": response.latency_ms
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": response.error,
|
||||||
|
"latency_ms": response.latency_ms
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": str(e),
|
||||||
|
"latency_ms": self._calculate_latency(start_time)
|
||||||
|
}
|
||||||
250
backend/adapters/gemini_adapter.py
Normal file
250
backend/adapters/gemini_adapter.py
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
"""
|
||||||
|
Gemini适配器
|
||||||
|
支持Google Gemini大模型API
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
|
||||||
|
from utils.proxy_handler import get_http_client
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiAdapter(BaseAdapter):
|
||||||
|
"""
|
||||||
|
Google Gemini API适配器
|
||||||
|
使用Gemini的原生API格式
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str,
|
||||||
|
base_url: str = "",
|
||||||
|
model: str = "gemini-1.5-pro",
|
||||||
|
use_proxy: bool = True, # Gemini通常需要代理
|
||||||
|
proxy_config: Optional[Dict[str, Any]] = None,
|
||||||
|
timeout: int = 60,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url or self.DEFAULT_BASE_URL,
|
||||||
|
model=model,
|
||||||
|
use_proxy=use_proxy,
|
||||||
|
proxy_config=proxy_config,
|
||||||
|
timeout=timeout,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def _convert_messages_to_gemini(
|
||||||
|
self,
|
||||||
|
messages: List[ChatMessage]
|
||||||
|
) -> tuple[str, List[Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
将消息转换为Gemini格式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: 标准消息列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(system_instruction, contents)
|
||||||
|
"""
|
||||||
|
system_instruction = ""
|
||||||
|
contents = []
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
if msg.role == "system":
|
||||||
|
system_instruction += msg.content + "\n"
|
||||||
|
else:
|
||||||
|
role = "user" if msg.role == "user" else "model"
|
||||||
|
contents.append({
|
||||||
|
"role": role,
|
||||||
|
"parts": [{"text": msg.content}]
|
||||||
|
})
|
||||||
|
|
||||||
|
return system_instruction.strip(), contents
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 2000,
|
||||||
|
**kwargs
|
||||||
|
) -> AdapterResponse:
|
||||||
|
"""发送聊天请求"""
|
||||||
|
start_time = datetime.utcnow()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with get_http_client(
|
||||||
|
use_proxy=self.use_proxy,
|
||||||
|
proxy_config=self.proxy_config,
|
||||||
|
timeout=self.timeout
|
||||||
|
) as client:
|
||||||
|
system_instruction, contents = self._convert_messages_to_gemini(messages)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"contents": contents,
|
||||||
|
"generationConfig": {
|
||||||
|
"temperature": temperature,
|
||||||
|
"maxOutputTokens": max_tokens,
|
||||||
|
"topP": kwargs.get("top_p", 0.95),
|
||||||
|
"topK": kwargs.get("top_k", 40)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# 添加系统指令
|
||||||
|
if system_instruction:
|
||||||
|
payload["systemInstruction"] = {
|
||||||
|
"parts": [{"text": system_instruction}]
|
||||||
|
}
|
||||||
|
|
||||||
|
url = f"{self.base_url}/models/{self.model}:generateContent?key={self.api_key}"
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
url,
|
||||||
|
json=payload
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
error_text = response.text
|
||||||
|
logger.error(f"Gemini API错误: {response.status_code} - {error_text}")
|
||||||
|
return AdapterResponse(
|
||||||
|
success=False,
|
||||||
|
error=f"API错误: {response.status_code} - {error_text}",
|
||||||
|
latency_ms=self._calculate_latency(start_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# 检查是否有候选回复
|
||||||
|
candidates = data.get("candidates", [])
|
||||||
|
if not candidates:
|
||||||
|
return AdapterResponse(
|
||||||
|
success=False,
|
||||||
|
error="没有生成回复",
|
||||||
|
latency_ms=self._calculate_latency(start_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
candidate = candidates[0]
|
||||||
|
content = candidate.get("content", {})
|
||||||
|
parts = content.get("parts", [])
|
||||||
|
text = "".join(part.get("text", "") for part in parts)
|
||||||
|
|
||||||
|
# 获取token使用情况
|
||||||
|
usage = data.get("usageMetadata", {})
|
||||||
|
|
||||||
|
return AdapterResponse(
|
||||||
|
success=True,
|
||||||
|
content=text,
|
||||||
|
model=self.model,
|
||||||
|
finish_reason=candidate.get("finishReason", ""),
|
||||||
|
prompt_tokens=usage.get("promptTokenCount", 0),
|
||||||
|
completion_tokens=usage.get("candidatesTokenCount", 0),
|
||||||
|
total_tokens=usage.get("totalTokenCount", 0),
|
||||||
|
latency_ms=self._calculate_latency(start_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Gemini请求异常: {e}")
|
||||||
|
return AdapterResponse(
|
||||||
|
success=False,
|
||||||
|
error=str(e),
|
||||||
|
latency_ms=self._calculate_latency(start_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat_stream(
|
||||||
|
self,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 2000,
|
||||||
|
**kwargs
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""发送流式聊天请求"""
|
||||||
|
try:
|
||||||
|
async with get_http_client(
|
||||||
|
use_proxy=self.use_proxy,
|
||||||
|
proxy_config=self.proxy_config,
|
||||||
|
timeout=self.timeout
|
||||||
|
) as client:
|
||||||
|
system_instruction, contents = self._convert_messages_to_gemini(messages)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"contents": contents,
|
||||||
|
"generationConfig": {
|
||||||
|
"temperature": temperature,
|
||||||
|
"maxOutputTokens": max_tokens,
|
||||||
|
"topP": kwargs.get("top_p", 0.95),
|
||||||
|
"topK": kwargs.get("top_k", 40)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if system_instruction:
|
||||||
|
payload["systemInstruction"] = {
|
||||||
|
"parts": [{"text": system_instruction}]
|
||||||
|
}
|
||||||
|
|
||||||
|
url = f"{self.base_url}/models/{self.model}:streamGenerateContent?key={self.api_key}&alt=sse"
|
||||||
|
|
||||||
|
async with client.stream(
|
||||||
|
"POST",
|
||||||
|
url,
|
||||||
|
json=payload
|
||||||
|
) as response:
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line.startswith("data: "):
|
||||||
|
data_str = line[6:]
|
||||||
|
try:
|
||||||
|
data = json.loads(data_str)
|
||||||
|
candidates = data.get("candidates", [])
|
||||||
|
if candidates:
|
||||||
|
content = candidates[0].get("content", {})
|
||||||
|
parts = content.get("parts", [])
|
||||||
|
for part in parts:
|
||||||
|
text = part.get("text", "")
|
||||||
|
if text:
|
||||||
|
yield text
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Gemini流式请求异常: {e}")
|
||||||
|
yield f"[错误: {str(e)}]"
|
||||||
|
|
||||||
|
async def test_connection(self) -> Dict[str, Any]:
|
||||||
|
"""测试API连接"""
|
||||||
|
start_time = datetime.utcnow()
|
||||||
|
|
||||||
|
try:
|
||||||
|
test_messages = [
|
||||||
|
ChatMessage(role="user", content="Hello, respond with 'OK'")
|
||||||
|
]
|
||||||
|
|
||||||
|
response = await self.chat(
|
||||||
|
messages=test_messages,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=10
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.success:
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": "连接成功",
|
||||||
|
"model": response.model,
|
||||||
|
"latency_ms": response.latency_ms
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": response.error,
|
||||||
|
"latency_ms": response.latency_ms
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": str(e),
|
||||||
|
"latency_ms": self._calculate_latency(start_time)
|
||||||
|
}
|
||||||
197
backend/adapters/kimi_adapter.py
Normal file
197
backend/adapters/kimi_adapter.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
"""
|
||||||
|
Kimi适配器
|
||||||
|
支持月之暗面Kimi大模型API
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
|
||||||
|
from utils.proxy_handler import get_http_client
|
||||||
|
|
||||||
|
|
||||||
|
class KimiAdapter(BaseAdapter):
|
||||||
|
"""
|
||||||
|
Kimi API适配器
|
||||||
|
兼容OpenAI API格式
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_BASE_URL = "https://api.moonshot.cn/v1"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str,
|
||||||
|
base_url: str = "",
|
||||||
|
model: str = "moonshot-v1-8k",
|
||||||
|
use_proxy: bool = False,
|
||||||
|
proxy_config: Optional[Dict[str, Any]] = None,
|
||||||
|
timeout: int = 60,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url or self.DEFAULT_BASE_URL,
|
||||||
|
model=model,
|
||||||
|
use_proxy=use_proxy,
|
||||||
|
proxy_config=proxy_config,
|
||||||
|
timeout=timeout,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 2000,
|
||||||
|
**kwargs
|
||||||
|
) -> AdapterResponse:
|
||||||
|
"""发送聊天请求"""
|
||||||
|
start_time = datetime.utcnow()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with get_http_client(
|
||||||
|
use_proxy=self.use_proxy,
|
||||||
|
proxy_config=self.proxy_config,
|
||||||
|
timeout=self.timeout
|
||||||
|
) as client:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": self._build_messages(messages),
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"{self.base_url}/chat/completions",
|
||||||
|
headers=headers,
|
||||||
|
json=payload
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
error_text = response.text
|
||||||
|
logger.error(f"Kimi API错误: {response.status_code} - {error_text}")
|
||||||
|
return AdapterResponse(
|
||||||
|
success=False,
|
||||||
|
error=f"API错误: {response.status_code} - {error_text}",
|
||||||
|
latency_ms=self._calculate_latency(start_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
choice = data.get("choices", [{}])[0]
|
||||||
|
message = choice.get("message", {})
|
||||||
|
usage = data.get("usage", {})
|
||||||
|
|
||||||
|
return AdapterResponse(
|
||||||
|
success=True,
|
||||||
|
content=message.get("content", ""),
|
||||||
|
model=data.get("model", self.model),
|
||||||
|
finish_reason=choice.get("finish_reason", ""),
|
||||||
|
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||||
|
completion_tokens=usage.get("completion_tokens", 0),
|
||||||
|
total_tokens=usage.get("total_tokens", 0),
|
||||||
|
latency_ms=self._calculate_latency(start_time),
|
||||||
|
tool_calls=message.get("tool_calls", [])
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Kimi请求异常: {e}")
|
||||||
|
return AdapterResponse(
|
||||||
|
success=False,
|
||||||
|
error=str(e),
|
||||||
|
latency_ms=self._calculate_latency(start_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat_stream(
|
||||||
|
self,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 2000,
|
||||||
|
**kwargs
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""发送流式聊天请求"""
|
||||||
|
try:
|
||||||
|
async with get_http_client(
|
||||||
|
use_proxy=self.use_proxy,
|
||||||
|
proxy_config=self.proxy_config,
|
||||||
|
timeout=self.timeout
|
||||||
|
) as client:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": self._build_messages(messages),
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"stream": True,
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
async with client.stream(
|
||||||
|
"POST",
|
||||||
|
f"{self.base_url}/chat/completions",
|
||||||
|
headers=headers,
|
||||||
|
json=payload
|
||||||
|
) as response:
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line.startswith("data: "):
|
||||||
|
data_str = line[6:]
|
||||||
|
if data_str.strip() == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
data = json.loads(data_str)
|
||||||
|
delta = data.get("choices", [{}])[0].get("delta", {})
|
||||||
|
content = delta.get("content", "")
|
||||||
|
if content:
|
||||||
|
yield content
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Kimi流式请求异常: {e}")
|
||||||
|
yield f"[错误: {str(e)}]"
|
||||||
|
|
||||||
|
async def test_connection(self) -> Dict[str, Any]:
|
||||||
|
"""测试API连接"""
|
||||||
|
start_time = datetime.utcnow()
|
||||||
|
|
||||||
|
try:
|
||||||
|
test_messages = [
|
||||||
|
ChatMessage(role="user", content="你好,请回复'OK'")
|
||||||
|
]
|
||||||
|
|
||||||
|
response = await self.chat(
|
||||||
|
messages=test_messages,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=10
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.success:
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": "连接成功",
|
||||||
|
"model": response.model,
|
||||||
|
"latency_ms": response.latency_ms
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": response.error,
|
||||||
|
"latency_ms": response.latency_ms
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": str(e),
|
||||||
|
"latency_ms": self._calculate_latency(start_time)
|
||||||
|
}
|
||||||
253
backend/adapters/llmstudio_adapter.py
Normal file
253
backend/adapters/llmstudio_adapter.py
Normal file
@@ -0,0 +1,253 @@
|
|||||||
|
"""
|
||||||
|
LLM Studio适配器
|
||||||
|
支持本地LLM Studio服务
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
|
||||||
|
from utils.proxy_handler import get_http_client
|
||||||
|
|
||||||
|
|
||||||
|
class LLMStudioAdapter(BaseAdapter):
|
||||||
|
"""
|
||||||
|
LLM Studio API适配器
|
||||||
|
兼容OpenAI API格式的本地服务
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_BASE_URL = "http://localhost:1234/v1"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str = "lm-studio", # LLM Studio使用固定key
|
||||||
|
base_url: str = "",
|
||||||
|
model: str = "local-model",
|
||||||
|
use_proxy: bool = False, # 本地服务不需要代理
|
||||||
|
proxy_config: Optional[Dict[str, Any]] = None,
|
||||||
|
timeout: int = 120, # 本地模型可能需要更长时间
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url or self.DEFAULT_BASE_URL,
|
||||||
|
model=model,
|
||||||
|
use_proxy=use_proxy,
|
||||||
|
proxy_config=proxy_config,
|
||||||
|
timeout=timeout,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 2000,
|
||||||
|
**kwargs
|
||||||
|
) -> AdapterResponse:
|
||||||
|
"""发送聊天请求"""
|
||||||
|
start_time = datetime.utcnow()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with get_http_client(
|
||||||
|
use_proxy=self.use_proxy,
|
||||||
|
proxy_config=self.proxy_config,
|
||||||
|
timeout=self.timeout
|
||||||
|
) as client:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": self._build_messages(messages),
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"{self.base_url}/chat/completions",
|
||||||
|
headers=headers,
|
||||||
|
json=payload
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
error_text = response.text
|
||||||
|
logger.error(f"LLM Studio API错误: {response.status_code} - {error_text}")
|
||||||
|
return AdapterResponse(
|
||||||
|
success=False,
|
||||||
|
error=f"API错误: {response.status_code} - {error_text}",
|
||||||
|
latency_ms=self._calculate_latency(start_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
choice = data.get("choices", [{}])[0]
|
||||||
|
message = choice.get("message", {})
|
||||||
|
usage = data.get("usage", {})
|
||||||
|
|
||||||
|
return AdapterResponse(
|
||||||
|
success=True,
|
||||||
|
content=message.get("content", ""),
|
||||||
|
model=data.get("model", self.model),
|
||||||
|
finish_reason=choice.get("finish_reason", ""),
|
||||||
|
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||||
|
completion_tokens=usage.get("completion_tokens", 0),
|
||||||
|
total_tokens=usage.get("total_tokens", 0),
|
||||||
|
latency_ms=self._calculate_latency(start_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"LLM Studio请求异常: {e}")
|
||||||
|
return AdapterResponse(
|
||||||
|
success=False,
|
||||||
|
error=str(e),
|
||||||
|
latency_ms=self._calculate_latency(start_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat_stream(
|
||||||
|
self,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 2000,
|
||||||
|
**kwargs
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""发送流式聊天请求"""
|
||||||
|
try:
|
||||||
|
async with get_http_client(
|
||||||
|
use_proxy=self.use_proxy,
|
||||||
|
proxy_config=self.proxy_config,
|
||||||
|
timeout=self.timeout
|
||||||
|
) as client:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": self._build_messages(messages),
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"stream": True,
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
async with client.stream(
|
||||||
|
"POST",
|
||||||
|
f"{self.base_url}/chat/completions",
|
||||||
|
headers=headers,
|
||||||
|
json=payload
|
||||||
|
) as response:
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line.startswith("data: "):
|
||||||
|
data_str = line[6:]
|
||||||
|
if data_str.strip() == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
data = json.loads(data_str)
|
||||||
|
delta = data.get("choices", [{}])[0].get("delta", {})
|
||||||
|
content = delta.get("content", "")
|
||||||
|
if content:
|
||||||
|
yield content
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"LLM Studio流式请求异常: {e}")
|
||||||
|
yield f"[错误: {str(e)}]"
|
||||||
|
|
||||||
|
async def test_connection(self) -> Dict[str, Any]:
|
||||||
|
"""测试API连接"""
|
||||||
|
start_time = datetime.utcnow()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 首先检查服务是否在运行
|
||||||
|
async with get_http_client(
|
||||||
|
use_proxy=self.use_proxy,
|
||||||
|
proxy_config=self.proxy_config,
|
||||||
|
timeout=10
|
||||||
|
) as client:
|
||||||
|
# 获取模型列表
|
||||||
|
response = await client.get(
|
||||||
|
f"{self.base_url}/models",
|
||||||
|
headers={"Authorization": f"Bearer {self.api_key}"}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": "LLM Studio服务未运行或不可访问",
|
||||||
|
"latency_ms": self._calculate_latency(start_time)
|
||||||
|
}
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
models = [m.get("id", "") for m in data.get("data", [])]
|
||||||
|
|
||||||
|
if not models:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": "LLM Studio中没有加载的模型",
|
||||||
|
"latency_ms": self._calculate_latency(start_time)
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送测试消息
|
||||||
|
test_messages = [
|
||||||
|
ChatMessage(role="user", content="Hello, respond with 'OK'")
|
||||||
|
]
|
||||||
|
|
||||||
|
response = await self.chat(
|
||||||
|
messages=test_messages,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=10
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.success:
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": "连接成功",
|
||||||
|
"model": response.model,
|
||||||
|
"latency_ms": response.latency_ms
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": response.error,
|
||||||
|
"latency_ms": response.latency_ms
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": str(e),
|
||||||
|
"latency_ms": self._calculate_latency(start_time)
|
||||||
|
}
|
||||||
|
|
||||||
|
async def list_models(self) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
列出LLM Studio中加载的模型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
模型信息列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async with get_http_client(
|
||||||
|
use_proxy=self.use_proxy,
|
||||||
|
proxy_config=self.proxy_config,
|
||||||
|
timeout=10
|
||||||
|
) as client:
|
||||||
|
response = await client.get(
|
||||||
|
f"{self.base_url}/models",
|
||||||
|
headers={"Authorization": f"Bearer {self.api_key}"}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
return data.get("data", [])
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取LLM Studio模型列表失败: {e}")
|
||||||
|
|
||||||
|
return []
|
||||||
251
backend/adapters/minimax_adapter.py
Normal file
251
backend/adapters/minimax_adapter.py
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
"""
|
||||||
|
MiniMax适配器
|
||||||
|
支持MiniMax大模型API
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
|
||||||
|
from utils.proxy_handler import get_http_client
|
||||||
|
|
||||||
|
|
||||||
|
class MiniMaxAdapter(BaseAdapter):
|
||||||
|
"""
|
||||||
|
MiniMax API适配器
|
||||||
|
支持abab系列模型
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_BASE_URL = "https://api.minimax.chat/v1"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str,
|
||||||
|
base_url: str = "",
|
||||||
|
model: str = "abab6.5-chat",
|
||||||
|
use_proxy: bool = False,
|
||||||
|
proxy_config: Optional[Dict[str, Any]] = None,
|
||||||
|
timeout: int = 60,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url or self.DEFAULT_BASE_URL,
|
||||||
|
model=model,
|
||||||
|
use_proxy=use_proxy,
|
||||||
|
proxy_config=proxy_config,
|
||||||
|
timeout=timeout,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
# MiniMax需要group_id
|
||||||
|
self.group_id = kwargs.get("group_id", "")
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 2000,
|
||||||
|
**kwargs
|
||||||
|
) -> AdapterResponse:
|
||||||
|
"""发送聊天请求"""
|
||||||
|
start_time = datetime.utcnow()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with get_http_client(
|
||||||
|
use_proxy=self.use_proxy,
|
||||||
|
proxy_config=self.proxy_config,
|
||||||
|
timeout=self.timeout
|
||||||
|
) as client:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# MiniMax使用特殊的消息格式
|
||||||
|
minimax_messages = []
|
||||||
|
bot_setting = []
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
if msg.role == "system":
|
||||||
|
bot_setting.append({
|
||||||
|
"bot_name": "assistant",
|
||||||
|
"content": msg.content
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
minimax_messages.append({
|
||||||
|
"sender_type": "USER" if msg.role == "user" else "BOT",
|
||||||
|
"sender_name": msg.name or ("用户" if msg.role == "user" else "assistant"),
|
||||||
|
"text": msg.content
|
||||||
|
})
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": minimax_messages,
|
||||||
|
"bot_setting": bot_setting if bot_setting else [{"bot_name": "assistant", "content": "你是一个有帮助的助手"}],
|
||||||
|
"temperature": temperature,
|
||||||
|
"tokens_to_generate": max_tokens,
|
||||||
|
"mask_sensitive_info": False,
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
url = f"{self.base_url}/text/chatcompletion_v2"
|
||||||
|
if self.group_id:
|
||||||
|
url = f"{url}?GroupId={self.group_id}"
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
url,
|
||||||
|
headers=headers,
|
||||||
|
json=payload
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
error_text = response.text
|
||||||
|
logger.error(f"MiniMax API错误: {response.status_code} - {error_text}")
|
||||||
|
return AdapterResponse(
|
||||||
|
success=False,
|
||||||
|
error=f"API错误: {response.status_code} - {error_text}",
|
||||||
|
latency_ms=self._calculate_latency(start_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# 检查API返回的错误
|
||||||
|
if data.get("base_resp", {}).get("status_code", 0) != 0:
|
||||||
|
error_msg = data.get("base_resp", {}).get("status_msg", "未知错误")
|
||||||
|
return AdapterResponse(
|
||||||
|
success=False,
|
||||||
|
error=f"API错误: {error_msg}",
|
||||||
|
latency_ms=self._calculate_latency(start_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
reply = data.get("reply", "")
|
||||||
|
usage = data.get("usage", {})
|
||||||
|
|
||||||
|
return AdapterResponse(
|
||||||
|
success=True,
|
||||||
|
content=reply,
|
||||||
|
model=self.model,
|
||||||
|
finish_reason=data.get("output_sensitive", False) and "content_filter" or "stop",
|
||||||
|
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||||
|
completion_tokens=usage.get("completion_tokens", 0),
|
||||||
|
total_tokens=usage.get("total_tokens", 0),
|
||||||
|
latency_ms=self._calculate_latency(start_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"MiniMax请求异常: {e}")
|
||||||
|
return AdapterResponse(
|
||||||
|
success=False,
|
||||||
|
error=str(e),
|
||||||
|
latency_ms=self._calculate_latency(start_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat_stream(
|
||||||
|
self,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 2000,
|
||||||
|
**kwargs
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""发送流式聊天请求"""
|
||||||
|
try:
|
||||||
|
async with get_http_client(
|
||||||
|
use_proxy=self.use_proxy,
|
||||||
|
proxy_config=self.proxy_config,
|
||||||
|
timeout=self.timeout
|
||||||
|
) as client:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
minimax_messages = []
|
||||||
|
bot_setting = []
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
if msg.role == "system":
|
||||||
|
bot_setting.append({
|
||||||
|
"bot_name": "assistant",
|
||||||
|
"content": msg.content
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
minimax_messages.append({
|
||||||
|
"sender_type": "USER" if msg.role == "user" else "BOT",
|
||||||
|
"sender_name": msg.name or ("用户" if msg.role == "user" else "assistant"),
|
||||||
|
"text": msg.content
|
||||||
|
})
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": minimax_messages,
|
||||||
|
"bot_setting": bot_setting if bot_setting else [{"bot_name": "assistant", "content": "你是一个有帮助的助手"}],
|
||||||
|
"temperature": temperature,
|
||||||
|
"tokens_to_generate": max_tokens,
|
||||||
|
"stream": True,
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
url = f"{self.base_url}/text/chatcompletion_v2"
|
||||||
|
if self.group_id:
|
||||||
|
url = f"{url}?GroupId={self.group_id}"
|
||||||
|
|
||||||
|
async with client.stream(
|
||||||
|
"POST",
|
||||||
|
url,
|
||||||
|
headers=headers,
|
||||||
|
json=payload
|
||||||
|
) as response:
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line.startswith("data: "):
|
||||||
|
data_str = line[6:]
|
||||||
|
if data_str.strip() == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
data = json.loads(data_str)
|
||||||
|
delta = data.get("choices", [{}])[0].get("delta", {})
|
||||||
|
content = delta.get("content", "")
|
||||||
|
if content:
|
||||||
|
yield content
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"MiniMax流式请求异常: {e}")
|
||||||
|
yield f"[错误: {str(e)}]"
|
||||||
|
|
||||||
|
async def test_connection(self) -> Dict[str, Any]:
|
||||||
|
"""测试API连接"""
|
||||||
|
start_time = datetime.utcnow()
|
||||||
|
|
||||||
|
try:
|
||||||
|
test_messages = [
|
||||||
|
ChatMessage(role="user", content="你好,请回复'OK'")
|
||||||
|
]
|
||||||
|
|
||||||
|
response = await self.chat(
|
||||||
|
messages=test_messages,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=10
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.success:
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": "连接成功",
|
||||||
|
"model": response.model,
|
||||||
|
"latency_ms": response.latency_ms
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": response.error,
|
||||||
|
"latency_ms": response.latency_ms
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": str(e),
|
||||||
|
"latency_ms": self._calculate_latency(start_time)
|
||||||
|
}
|
||||||
241
backend/adapters/ollama_adapter.py
Normal file
241
backend/adapters/ollama_adapter.py
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
"""
|
||||||
|
Ollama适配器
|
||||||
|
支持本地Ollama服务
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
|
||||||
|
from utils.proxy_handler import get_http_client
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaAdapter(BaseAdapter):
|
||||||
|
"""
|
||||||
|
Ollama API适配器
|
||||||
|
用于连接本地Ollama服务
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_BASE_URL = "http://localhost:11434"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str = "", # Ollama通常不需要API密钥
|
||||||
|
base_url: str = "",
|
||||||
|
model: str = "llama2",
|
||||||
|
use_proxy: bool = False, # 本地服务通常不需要代理
|
||||||
|
proxy_config: Optional[Dict[str, Any]] = None,
|
||||||
|
timeout: int = 120, # 本地模型可能需要更长时间
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url or self.DEFAULT_BASE_URL,
|
||||||
|
model=model,
|
||||||
|
use_proxy=use_proxy,
|
||||||
|
proxy_config=proxy_config,
|
||||||
|
timeout=timeout,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 2000,
|
||||||
|
**kwargs
|
||||||
|
) -> AdapterResponse:
|
||||||
|
"""发送聊天请求"""
|
||||||
|
start_time = datetime.utcnow()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with get_http_client(
|
||||||
|
use_proxy=self.use_proxy,
|
||||||
|
proxy_config=self.proxy_config,
|
||||||
|
timeout=self.timeout
|
||||||
|
) as client:
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": self._build_messages(messages),
|
||||||
|
"options": {
|
||||||
|
"temperature": temperature,
|
||||||
|
"num_predict": max_tokens,
|
||||||
|
},
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"{self.base_url}/api/chat",
|
||||||
|
json=payload
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
error_text = response.text
|
||||||
|
logger.error(f"Ollama API错误: {response.status_code} - {error_text}")
|
||||||
|
return AdapterResponse(
|
||||||
|
success=False,
|
||||||
|
error=f"API错误: {response.status_code} - {error_text}",
|
||||||
|
latency_ms=self._calculate_latency(start_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
message = data.get("message", {})
|
||||||
|
|
||||||
|
return AdapterResponse(
|
||||||
|
success=True,
|
||||||
|
content=message.get("content", ""),
|
||||||
|
model=data.get("model", self.model),
|
||||||
|
finish_reason=data.get("done_reason", "stop"),
|
||||||
|
prompt_tokens=data.get("prompt_eval_count", 0),
|
||||||
|
completion_tokens=data.get("eval_count", 0),
|
||||||
|
total_tokens=(
|
||||||
|
data.get("prompt_eval_count", 0) +
|
||||||
|
data.get("eval_count", 0)
|
||||||
|
),
|
||||||
|
latency_ms=self._calculate_latency(start_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Ollama请求异常: {e}")
|
||||||
|
return AdapterResponse(
|
||||||
|
success=False,
|
||||||
|
error=str(e),
|
||||||
|
latency_ms=self._calculate_latency(start_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat_stream(
|
||||||
|
self,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 2000,
|
||||||
|
**kwargs
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""发送流式聊天请求"""
|
||||||
|
try:
|
||||||
|
async with get_http_client(
|
||||||
|
use_proxy=self.use_proxy,
|
||||||
|
proxy_config=self.proxy_config,
|
||||||
|
timeout=self.timeout
|
||||||
|
) as client:
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": self._build_messages(messages),
|
||||||
|
"options": {
|
||||||
|
"temperature": temperature,
|
||||||
|
"num_predict": max_tokens,
|
||||||
|
},
|
||||||
|
"stream": True
|
||||||
|
}
|
||||||
|
|
||||||
|
async with client.stream(
|
||||||
|
"POST",
|
||||||
|
f"{self.base_url}/api/chat",
|
||||||
|
json=payload
|
||||||
|
) as response:
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line:
|
||||||
|
try:
|
||||||
|
data = json.loads(line)
|
||||||
|
message = data.get("message", {})
|
||||||
|
content = message.get("content", "")
|
||||||
|
if content:
|
||||||
|
yield content
|
||||||
|
|
||||||
|
# 检查是否完成
|
||||||
|
if data.get("done", False):
|
||||||
|
break
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Ollama流式请求异常: {e}")
|
||||||
|
yield f"[错误: {str(e)}]"
|
||||||
|
|
||||||
|
async def test_connection(self) -> Dict[str, Any]:
|
||||||
|
"""测试API连接"""
|
||||||
|
start_time = datetime.utcnow()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 首先检查服务是否在运行
|
||||||
|
async with get_http_client(
|
||||||
|
use_proxy=self.use_proxy,
|
||||||
|
proxy_config=self.proxy_config,
|
||||||
|
timeout=10
|
||||||
|
) as client:
|
||||||
|
# 检查模型是否存在
|
||||||
|
response = await client.get(f"{self.base_url}/api/tags")
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": "Ollama服务未运行或不可访问",
|
||||||
|
"latency_ms": self._calculate_latency(start_time)
|
||||||
|
}
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
models = [m.get("name", "").split(":")[0] for m in data.get("models", [])]
|
||||||
|
|
||||||
|
model_name = self.model.split(":")[0]
|
||||||
|
if model_name not in models:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": f"模型 {self.model} 未安装,可用模型: {', '.join(models)}",
|
||||||
|
"latency_ms": self._calculate_latency(start_time)
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送测试消息
|
||||||
|
test_messages = [
|
||||||
|
ChatMessage(role="user", content="Hello, respond with 'OK'")
|
||||||
|
]
|
||||||
|
|
||||||
|
response = await self.chat(
|
||||||
|
messages=test_messages,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=10
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.success:
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": "连接成功",
|
||||||
|
"model": response.model,
|
||||||
|
"latency_ms": response.latency_ms
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": response.error,
|
||||||
|
"latency_ms": response.latency_ms
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": str(e),
|
||||||
|
"latency_ms": self._calculate_latency(start_time)
|
||||||
|
}
|
||||||
|
|
||||||
|
async def list_models(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
列出本地可用的模型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
模型名称列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async with get_http_client(
|
||||||
|
use_proxy=self.use_proxy,
|
||||||
|
proxy_config=self.proxy_config,
|
||||||
|
timeout=10
|
||||||
|
) as client:
|
||||||
|
response = await client.get(f"{self.base_url}/api/tags")
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
return [m.get("name", "") for m in data.get("models", [])]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取Ollama模型列表失败: {e}")
|
||||||
|
|
||||||
|
return []
|
||||||
201
backend/adapters/openrouter_adapter.py
Normal file
201
backend/adapters/openrouter_adapter.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
"""
|
||||||
|
OpenRouter适配器
|
||||||
|
支持通过OpenRouter访问多种AI模型
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
|
||||||
|
from utils.proxy_handler import get_http_client
|
||||||
|
|
||||||
|
|
||||||
|
class OpenRouterAdapter(BaseAdapter):
|
||||||
|
"""
|
||||||
|
OpenRouter API适配器
|
||||||
|
兼容OpenAI API格式
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_BASE_URL = "https://openrouter.ai/api/v1"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str,
|
||||||
|
base_url: str = "",
|
||||||
|
model: str = "openai/gpt-4-turbo",
|
||||||
|
use_proxy: bool = False,
|
||||||
|
proxy_config: Optional[Dict[str, Any]] = None,
|
||||||
|
timeout: int = 60,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url or self.DEFAULT_BASE_URL,
|
||||||
|
model=model,
|
||||||
|
use_proxy=use_proxy,
|
||||||
|
proxy_config=proxy_config,
|
||||||
|
timeout=timeout,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 2000,
|
||||||
|
**kwargs
|
||||||
|
) -> AdapterResponse:
|
||||||
|
"""发送聊天请求"""
|
||||||
|
start_time = datetime.utcnow()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with get_http_client(
|
||||||
|
use_proxy=self.use_proxy,
|
||||||
|
proxy_config=self.proxy_config,
|
||||||
|
timeout=self.timeout
|
||||||
|
) as client:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"HTTP-Referer": kwargs.get("referer", "https://ai-chatroom.local"),
|
||||||
|
"X-Title": kwargs.get("title", "AI ChatRoom")
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": self._build_messages(messages),
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"{self.base_url}/chat/completions",
|
||||||
|
headers=headers,
|
||||||
|
json=payload
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
error_text = response.text
|
||||||
|
logger.error(f"OpenRouter API错误: {response.status_code} - {error_text}")
|
||||||
|
return AdapterResponse(
|
||||||
|
success=False,
|
||||||
|
error=f"API错误: {response.status_code} - {error_text}",
|
||||||
|
latency_ms=self._calculate_latency(start_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
choice = data.get("choices", [{}])[0]
|
||||||
|
message = choice.get("message", {})
|
||||||
|
usage = data.get("usage", {})
|
||||||
|
|
||||||
|
return AdapterResponse(
|
||||||
|
success=True,
|
||||||
|
content=message.get("content", ""),
|
||||||
|
model=data.get("model", self.model),
|
||||||
|
finish_reason=choice.get("finish_reason", ""),
|
||||||
|
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||||
|
completion_tokens=usage.get("completion_tokens", 0),
|
||||||
|
total_tokens=usage.get("total_tokens", 0),
|
||||||
|
latency_ms=self._calculate_latency(start_time),
|
||||||
|
tool_calls=message.get("tool_calls", [])
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"OpenRouter请求异常: {e}")
|
||||||
|
return AdapterResponse(
|
||||||
|
success=False,
|
||||||
|
error=str(e),
|
||||||
|
latency_ms=self._calculate_latency(start_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat_stream(
|
||||||
|
self,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 2000,
|
||||||
|
**kwargs
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""发送流式聊天请求"""
|
||||||
|
try:
|
||||||
|
async with get_http_client(
|
||||||
|
use_proxy=self.use_proxy,
|
||||||
|
proxy_config=self.proxy_config,
|
||||||
|
timeout=self.timeout
|
||||||
|
) as client:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"HTTP-Referer": kwargs.get("referer", "https://ai-chatroom.local"),
|
||||||
|
"X-Title": kwargs.get("title", "AI ChatRoom")
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": self._build_messages(messages),
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"stream": True,
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
async with client.stream(
|
||||||
|
"POST",
|
||||||
|
f"{self.base_url}/chat/completions",
|
||||||
|
headers=headers,
|
||||||
|
json=payload
|
||||||
|
) as response:
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line.startswith("data: "):
|
||||||
|
data_str = line[6:]
|
||||||
|
if data_str.strip() == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
data = json.loads(data_str)
|
||||||
|
delta = data.get("choices", [{}])[0].get("delta", {})
|
||||||
|
content = delta.get("content", "")
|
||||||
|
if content:
|
||||||
|
yield content
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"OpenRouter流式请求异常: {e}")
|
||||||
|
yield f"[错误: {str(e)}]"
|
||||||
|
|
||||||
|
async def test_connection(self) -> Dict[str, Any]:
|
||||||
|
"""测试API连接"""
|
||||||
|
start_time = datetime.utcnow()
|
||||||
|
|
||||||
|
try:
|
||||||
|
test_messages = [
|
||||||
|
ChatMessage(role="user", content="Hello, respond with 'OK'")
|
||||||
|
]
|
||||||
|
|
||||||
|
response = await self.chat(
|
||||||
|
messages=test_messages,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=10
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.success:
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": "连接成功",
|
||||||
|
"model": response.model,
|
||||||
|
"latency_ms": response.latency_ms
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": response.error,
|
||||||
|
"latency_ms": response.latency_ms
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": str(e),
|
||||||
|
"latency_ms": self._calculate_latency(start_time)
|
||||||
|
}
|
||||||
197
backend/adapters/zhipu_adapter.py
Normal file
197
backend/adapters/zhipu_adapter.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
"""
|
||||||
|
智谱AI适配器
|
||||||
|
支持智谱GLM系列模型
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
|
||||||
|
from utils.proxy_handler import get_http_client
|
||||||
|
|
||||||
|
|
||||||
|
class ZhipuAdapter(BaseAdapter):
|
||||||
|
"""
|
||||||
|
智谱AI API适配器
|
||||||
|
支持GLM-4、GLM-3等模型
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_BASE_URL = "https://open.bigmodel.cn/api/paas/v4"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str,
|
||||||
|
base_url: str = "",
|
||||||
|
model: str = "glm-4",
|
||||||
|
use_proxy: bool = False,
|
||||||
|
proxy_config: Optional[Dict[str, Any]] = None,
|
||||||
|
timeout: int = 60,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url or self.DEFAULT_BASE_URL,
|
||||||
|
model=model,
|
||||||
|
use_proxy=use_proxy,
|
||||||
|
proxy_config=proxy_config,
|
||||||
|
timeout=timeout,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 2000,
|
||||||
|
**kwargs
|
||||||
|
) -> AdapterResponse:
|
||||||
|
"""发送聊天请求"""
|
||||||
|
start_time = datetime.utcnow()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with get_http_client(
|
||||||
|
use_proxy=self.use_proxy,
|
||||||
|
proxy_config=self.proxy_config,
|
||||||
|
timeout=self.timeout
|
||||||
|
) as client:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": self._build_messages(messages),
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"{self.base_url}/chat/completions",
|
||||||
|
headers=headers,
|
||||||
|
json=payload
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
error_text = response.text
|
||||||
|
logger.error(f"智谱API错误: {response.status_code} - {error_text}")
|
||||||
|
return AdapterResponse(
|
||||||
|
success=False,
|
||||||
|
error=f"API错误: {response.status_code} - {error_text}",
|
||||||
|
latency_ms=self._calculate_latency(start_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
choice = data.get("choices", [{}])[0]
|
||||||
|
message = choice.get("message", {})
|
||||||
|
usage = data.get("usage", {})
|
||||||
|
|
||||||
|
return AdapterResponse(
|
||||||
|
success=True,
|
||||||
|
content=message.get("content", ""),
|
||||||
|
model=data.get("model", self.model),
|
||||||
|
finish_reason=choice.get("finish_reason", ""),
|
||||||
|
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||||
|
completion_tokens=usage.get("completion_tokens", 0),
|
||||||
|
total_tokens=usage.get("total_tokens", 0),
|
||||||
|
latency_ms=self._calculate_latency(start_time),
|
||||||
|
tool_calls=message.get("tool_calls", [])
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"智谱API请求异常: {e}")
|
||||||
|
return AdapterResponse(
|
||||||
|
success=False,
|
||||||
|
error=str(e),
|
||||||
|
latency_ms=self._calculate_latency(start_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat_stream(
|
||||||
|
self,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 2000,
|
||||||
|
**kwargs
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""发送流式聊天请求"""
|
||||||
|
try:
|
||||||
|
async with get_http_client(
|
||||||
|
use_proxy=self.use_proxy,
|
||||||
|
proxy_config=self.proxy_config,
|
||||||
|
timeout=self.timeout
|
||||||
|
) as client:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": self._build_messages(messages),
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"stream": True,
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
async with client.stream(
|
||||||
|
"POST",
|
||||||
|
f"{self.base_url}/chat/completions",
|
||||||
|
headers=headers,
|
||||||
|
json=payload
|
||||||
|
) as response:
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line.startswith("data: "):
|
||||||
|
data_str = line[6:]
|
||||||
|
if data_str.strip() == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
data = json.loads(data_str)
|
||||||
|
delta = data.get("choices", [{}])[0].get("delta", {})
|
||||||
|
content = delta.get("content", "")
|
||||||
|
if content:
|
||||||
|
yield content
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"智谱流式请求异常: {e}")
|
||||||
|
yield f"[错误: {str(e)}]"
|
||||||
|
|
||||||
|
async def test_connection(self) -> Dict[str, Any]:
|
||||||
|
"""测试API连接"""
|
||||||
|
start_time = datetime.utcnow()
|
||||||
|
|
||||||
|
try:
|
||||||
|
test_messages = [
|
||||||
|
ChatMessage(role="user", content="你好,请回复'OK'")
|
||||||
|
]
|
||||||
|
|
||||||
|
response = await self.chat(
|
||||||
|
messages=test_messages,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=10
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.success:
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": "连接成功",
|
||||||
|
"model": response.model,
|
||||||
|
"latency_ms": response.latency_ms
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": response.error,
|
||||||
|
"latency_ms": response.latency_ms
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": str(e),
|
||||||
|
"latency_ms": self._calculate_latency(start_time)
|
||||||
|
}
|
||||||
50
backend/config.py
Normal file
50
backend/config.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""
|
||||||
|
应用配置模块
|
||||||
|
从环境变量加载配置项
|
||||||
|
"""
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
"""应用配置类"""
|
||||||
|
|
||||||
|
# MongoDB配置
|
||||||
|
MONGODB_URL: str = "mongodb://localhost:27017"
|
||||||
|
MONGODB_DB: str = "ai_chatroom"
|
||||||
|
|
||||||
|
# 服务配置
|
||||||
|
HOST: str = "0.0.0.0"
|
||||||
|
PORT: int = 8000
|
||||||
|
DEBUG: bool = True
|
||||||
|
|
||||||
|
# 安全配置
|
||||||
|
SECRET_KEY: str = "your-secret-key-change-in-production"
|
||||||
|
ENCRYPTION_KEY: str = "your-encryption-key-32-bytes-long"
|
||||||
|
|
||||||
|
# CORS配置
|
||||||
|
CORS_ORIGINS: list = ["http://localhost:3000", "http://127.0.0.1:3000"]
|
||||||
|
|
||||||
|
# WebSocket配置
|
||||||
|
WS_HEARTBEAT_INTERVAL: int = 30
|
||||||
|
|
||||||
|
# 默认AI配置
|
||||||
|
DEFAULT_TIMEOUT: int = 60
|
||||||
|
DEFAULT_MAX_TOKENS: int = 2000
|
||||||
|
DEFAULT_TEMPERATURE: float = 0.7
|
||||||
|
|
||||||
|
# 代理配置(全局默认)
|
||||||
|
DEFAULT_HTTP_PROXY: Optional[str] = None
|
||||||
|
DEFAULT_HTTPS_PROXY: Optional[str] = None
|
||||||
|
|
||||||
|
# 速率限制
|
||||||
|
RATE_LIMIT_REQUESTS: int = 100
|
||||||
|
RATE_LIMIT_PERIOD: int = 60 # 秒
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
env_file_encoding = "utf-8"
|
||||||
|
|
||||||
|
|
||||||
|
# 全局配置实例
|
||||||
|
settings = Settings()
|
||||||
10
backend/database/__init__.py
Normal file
10
backend/database/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
数据库模块
|
||||||
|
"""
|
||||||
|
from .connection import connect_db, close_db, get_database
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"connect_db",
|
||||||
|
"close_db",
|
||||||
|
"get_database",
|
||||||
|
]
|
||||||
94
backend/database/connection.py
Normal file
94
backend/database/connection.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
"""
|
||||||
|
MongoDB数据库连接模块
|
||||||
|
使用Motor异步驱动
|
||||||
|
"""
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
|
||||||
|
from beanie import init_beanie
|
||||||
|
from loguru import logger
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from config import settings
|
||||||
|
|
||||||
|
# 全局数据库客户端和数据库实例
|
||||||
|
_client: Optional[AsyncIOMotorClient] = None
|
||||||
|
_database: Optional[AsyncIOMotorDatabase] = None
|
||||||
|
|
||||||
|
|
||||||
|
async def connect_db() -> None:
|
||||||
|
"""
|
||||||
|
连接MongoDB数据库
|
||||||
|
初始化Beanie ODM
|
||||||
|
"""
|
||||||
|
global _client, _database
|
||||||
|
|
||||||
|
try:
|
||||||
|
_client = AsyncIOMotorClient(settings.MONGODB_URL)
|
||||||
|
_database = _client[settings.MONGODB_DB]
|
||||||
|
|
||||||
|
# 导入所有文档模型用于初始化Beanie
|
||||||
|
from models.ai_provider import AIProvider
|
||||||
|
from models.agent import Agent
|
||||||
|
from models.chatroom import ChatRoom
|
||||||
|
from models.message import Message
|
||||||
|
from models.discussion_result import DiscussionResult
|
||||||
|
from models.agent_memory import AgentMemory
|
||||||
|
|
||||||
|
# 初始化Beanie
|
||||||
|
await init_beanie(
|
||||||
|
database=_database,
|
||||||
|
document_models=[
|
||||||
|
AIProvider,
|
||||||
|
Agent,
|
||||||
|
ChatRoom,
|
||||||
|
Message,
|
||||||
|
DiscussionResult,
|
||||||
|
AgentMemory,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"已连接到MongoDB数据库: {settings.MONGODB_DB}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"数据库连接失败: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def close_db() -> None:
|
||||||
|
"""
|
||||||
|
关闭数据库连接
|
||||||
|
"""
|
||||||
|
global _client
|
||||||
|
|
||||||
|
if _client:
|
||||||
|
_client.close()
|
||||||
|
logger.info("数据库连接已关闭")
|
||||||
|
|
||||||
|
|
||||||
|
def get_database() -> AsyncIOMotorDatabase:
|
||||||
|
"""
|
||||||
|
获取数据库实例
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MongoDB数据库实例
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: 数据库未初始化
|
||||||
|
"""
|
||||||
|
if _database is None:
|
||||||
|
raise RuntimeError("数据库未初始化,请先调用connect_db()")
|
||||||
|
return _database
|
||||||
|
|
||||||
|
|
||||||
|
def get_client() -> AsyncIOMotorClient:
|
||||||
|
"""
|
||||||
|
获取数据库客户端
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MongoDB客户端实例
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: 客户端未初始化
|
||||||
|
"""
|
||||||
|
if _client is None:
|
||||||
|
raise RuntimeError("数据库客户端未初始化")
|
||||||
|
return _client
|
||||||
73
backend/main.py
Normal file
73
backend/main.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
"""
|
||||||
|
AI聊天室后端主入口
|
||||||
|
FastAPI应用启动文件
|
||||||
|
"""
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from config import settings
|
||||||
|
from database.connection import connect_db, close_db
|
||||||
|
from routers import providers, agents, chatrooms, discussions
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""
|
||||||
|
应用生命周期管理
|
||||||
|
启动时连接数据库,关闭时断开连接
|
||||||
|
"""
|
||||||
|
logger.info("正在启动AI聊天室服务...")
|
||||||
|
await connect_db()
|
||||||
|
logger.info("数据库连接成功")
|
||||||
|
yield
|
||||||
|
logger.info("正在关闭AI聊天室服务...")
|
||||||
|
await close_db()
|
||||||
|
logger.info("服务已关闭")
|
||||||
|
|
||||||
|
|
||||||
|
# 创建FastAPI应用
|
||||||
|
app = FastAPI(
|
||||||
|
title="AI聊天室",
|
||||||
|
description="多Agent协作讨论平台",
|
||||||
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
|
)
|
||||||
|
|
||||||
|
# 配置CORS
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=settings.CORS_ORIGINS,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册路由
|
||||||
|
app.include_router(providers.router, prefix="/api/providers", tags=["AI接口管理"])
|
||||||
|
app.include_router(agents.router, prefix="/api/agents", tags=["Agent管理"])
|
||||||
|
app.include_router(chatrooms.router, prefix="/api/chatrooms", tags=["聊天室管理"])
|
||||||
|
app.include_router(discussions.router, prefix="/api/discussions", tags=["讨论结果"])
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def root():
|
||||||
|
"""根路径健康检查"""
|
||||||
|
return {"message": "AI聊天室服务运行中", "version": "1.0.0"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check():
|
||||||
|
"""健康检查接口"""
|
||||||
|
return {"status": "healthy"}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
uvicorn.run(
|
||||||
|
"main:app",
|
||||||
|
host=settings.HOST,
|
||||||
|
port=settings.PORT,
|
||||||
|
reload=settings.DEBUG
|
||||||
|
)
|
||||||
25
backend/models/__init__.py
Normal file
25
backend/models/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
"""
|
||||||
|
数据模型模块
|
||||||
|
"""
|
||||||
|
from .ai_provider import AIProvider, ProxyConfig, RateLimit
|
||||||
|
from .agent import Agent, AgentCapabilities, AgentBehavior
|
||||||
|
from .chatroom import ChatRoom, ChatRoomConfig
|
||||||
|
from .message import Message, MessageType
|
||||||
|
from .discussion_result import DiscussionResult
|
||||||
|
from .agent_memory import AgentMemory, MemoryType
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AIProvider",
|
||||||
|
"ProxyConfig",
|
||||||
|
"RateLimit",
|
||||||
|
"Agent",
|
||||||
|
"AgentCapabilities",
|
||||||
|
"AgentBehavior",
|
||||||
|
"ChatRoom",
|
||||||
|
"ChatRoomConfig",
|
||||||
|
"Message",
|
||||||
|
"MessageType",
|
||||||
|
"DiscussionResult",
|
||||||
|
"AgentMemory",
|
||||||
|
"MemoryType",
|
||||||
|
]
|
||||||
168
backend/models/agent.py
Normal file
168
backend/models/agent.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
"""
|
||||||
|
Agent数据模型
|
||||||
|
定义AI聊天代理的配置结构
|
||||||
|
"""
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, Dict, Any, List
|
||||||
|
from pydantic import Field
|
||||||
|
from beanie import Document
|
||||||
|
|
||||||
|
|
||||||
|
class AgentCapabilities:
|
||||||
|
"""Agent能力配置"""
|
||||||
|
memory_enabled: bool = False # 是否启用记忆
|
||||||
|
mcp_tools: List[str] = [] # 可用的MCP工具
|
||||||
|
skills: List[str] = [] # 可用的技能
|
||||||
|
multimodal: bool = False # 是否支持多模态
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
memory_enabled: bool = False,
|
||||||
|
mcp_tools: Optional[List[str]] = None,
|
||||||
|
skills: Optional[List[str]] = None,
|
||||||
|
multimodal: bool = False
|
||||||
|
):
|
||||||
|
self.memory_enabled = memory_enabled
|
||||||
|
self.mcp_tools = mcp_tools or []
|
||||||
|
self.skills = skills or []
|
||||||
|
self.multimodal = multimodal
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""转换为字典"""
|
||||||
|
return {
|
||||||
|
"memory_enabled": self.memory_enabled,
|
||||||
|
"mcp_tools": self.mcp_tools,
|
||||||
|
"skills": self.skills,
|
||||||
|
"multimodal": self.multimodal
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "AgentCapabilities":
|
||||||
|
"""从字典创建"""
|
||||||
|
if not data:
|
||||||
|
return cls()
|
||||||
|
return cls(
|
||||||
|
memory_enabled=data.get("memory_enabled", False),
|
||||||
|
mcp_tools=data.get("mcp_tools", []),
|
||||||
|
skills=data.get("skills", []),
|
||||||
|
multimodal=data.get("multimodal", False)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentBehavior:
|
||||||
|
"""Agent行为配置"""
|
||||||
|
speak_threshold: float = 0.5 # 发言阈值(判断是否需要发言)
|
||||||
|
max_speak_per_round: int = 2 # 每轮最多发言次数
|
||||||
|
speak_style: str = "balanced" # 发言风格: concise, balanced, detailed
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
speak_threshold: float = 0.5,
|
||||||
|
max_speak_per_round: int = 2,
|
||||||
|
speak_style: str = "balanced"
|
||||||
|
):
|
||||||
|
self.speak_threshold = speak_threshold
|
||||||
|
self.max_speak_per_round = max_speak_per_round
|
||||||
|
self.speak_style = speak_style
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""转换为字典"""
|
||||||
|
return {
|
||||||
|
"speak_threshold": self.speak_threshold,
|
||||||
|
"max_speak_per_round": self.max_speak_per_round,
|
||||||
|
"speak_style": self.speak_style
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "AgentBehavior":
|
||||||
|
"""从字典创建"""
|
||||||
|
if not data:
|
||||||
|
return cls()
|
||||||
|
return cls(
|
||||||
|
speak_threshold=data.get("speak_threshold", 0.5),
|
||||||
|
max_speak_per_round=data.get("max_speak_per_round", 2),
|
||||||
|
speak_style=data.get("speak_style", "balanced")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Agent(Document):
|
||||||
|
"""
|
||||||
|
Agent文档模型
|
||||||
|
存储AI代理的配置信息
|
||||||
|
"""
|
||||||
|
agent_id: str = Field(..., description="唯一标识")
|
||||||
|
name: str = Field(..., description="Agent名称")
|
||||||
|
role: str = Field(..., description="角色定义")
|
||||||
|
system_prompt: str = Field(..., description="系统提示词")
|
||||||
|
provider_id: str = Field(..., description="使用的AI接口ID")
|
||||||
|
|
||||||
|
# 模型参数
|
||||||
|
temperature: float = Field(default=0.7, ge=0, le=2, description="温度参数")
|
||||||
|
max_tokens: int = Field(default=2000, gt=0, description="最大token数")
|
||||||
|
|
||||||
|
# 能力配置
|
||||||
|
capabilities: Dict[str, Any] = Field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"memory_enabled": False,
|
||||||
|
"mcp_tools": [],
|
||||||
|
"skills": [],
|
||||||
|
"multimodal": False
|
||||||
|
},
|
||||||
|
description="能力配置"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 行为配置
|
||||||
|
behavior: Dict[str, Any] = Field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"speak_threshold": 0.5,
|
||||||
|
"max_speak_per_round": 2,
|
||||||
|
"speak_style": "balanced"
|
||||||
|
},
|
||||||
|
description="行为配置"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 外观配置
|
||||||
|
avatar: Optional[str] = Field(default=None, description="头像URL")
|
||||||
|
color: str = Field(default="#1890ff", description="代表颜色")
|
||||||
|
|
||||||
|
# 元数据
|
||||||
|
enabled: bool = Field(default=True, description="是否启用")
|
||||||
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
class Settings:
|
||||||
|
name = "agents"
|
||||||
|
|
||||||
|
def get_capabilities(self) -> AgentCapabilities:
|
||||||
|
"""获取能力配置对象"""
|
||||||
|
return AgentCapabilities.from_dict(self.capabilities)
|
||||||
|
|
||||||
|
def get_behavior(self) -> AgentBehavior:
|
||||||
|
"""获取行为配置对象"""
|
||||||
|
return AgentBehavior.from_dict(self.behavior)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"agent_id": "product-manager",
|
||||||
|
"name": "产品经理",
|
||||||
|
"role": "产品规划和需求分析专家",
|
||||||
|
"system_prompt": "你是一位经验丰富的产品经理,擅长分析用户需求...",
|
||||||
|
"provider_id": "openrouter-gpt4",
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 2000,
|
||||||
|
"capabilities": {
|
||||||
|
"memory_enabled": True,
|
||||||
|
"mcp_tools": ["web_search"],
|
||||||
|
"skills": [],
|
||||||
|
"multimodal": False
|
||||||
|
},
|
||||||
|
"behavior": {
|
||||||
|
"speak_threshold": 0.5,
|
||||||
|
"max_speak_per_round": 2,
|
||||||
|
"speak_style": "balanced"
|
||||||
|
},
|
||||||
|
"avatar": "https://example.com/avatar.png",
|
||||||
|
"color": "#1890ff"
|
||||||
|
}
|
||||||
|
}
|
||||||
123
backend/models/agent_memory.py
Normal file
123
backend/models/agent_memory.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
"""
|
||||||
|
Agent记忆数据模型
|
||||||
|
定义Agent的记忆存储结构
|
||||||
|
"""
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List
|
||||||
|
from enum import Enum
|
||||||
|
from pydantic import Field
|
||||||
|
from beanie import Document
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryType(str, Enum):
|
||||||
|
"""记忆类型枚举"""
|
||||||
|
SHORT_TERM = "short_term" # 短期记忆(会话内)
|
||||||
|
LONG_TERM = "long_term" # 长期记忆(跨会话)
|
||||||
|
EPISODIC = "episodic" # 情景记忆(特定事件)
|
||||||
|
SEMANTIC = "semantic" # 语义记忆(知识性)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentMemory(Document):
|
||||||
|
"""
|
||||||
|
Agent记忆文档模型
|
||||||
|
存储Agent的记忆内容
|
||||||
|
"""
|
||||||
|
memory_id: str = Field(..., description="唯一标识")
|
||||||
|
agent_id: str = Field(..., description="Agent ID")
|
||||||
|
|
||||||
|
# 记忆内容
|
||||||
|
memory_type: str = Field(
|
||||||
|
default=MemoryType.SHORT_TERM.value,
|
||||||
|
description="记忆类型"
|
||||||
|
)
|
||||||
|
content: str = Field(..., description="记忆内容")
|
||||||
|
summary: str = Field(default="", description="内容摘要")
|
||||||
|
|
||||||
|
# 向量嵌入(用于相似度检索)
|
||||||
|
embedding: List[float] = Field(default_factory=list, description="向量嵌入")
|
||||||
|
|
||||||
|
# 元数据
|
||||||
|
importance: float = Field(default=0.5, ge=0, le=1, description="重要性评分")
|
||||||
|
access_count: int = Field(default=0, description="访问次数")
|
||||||
|
|
||||||
|
# 关联信息
|
||||||
|
source_room_id: Optional[str] = Field(default=None, description="来源聊天室ID")
|
||||||
|
source_discussion_id: Optional[str] = Field(default=None, description="来源讨论ID")
|
||||||
|
related_agents: List[str] = Field(default_factory=list, description="相关Agent列表")
|
||||||
|
tags: List[str] = Field(default_factory=list, description="标签")
|
||||||
|
|
||||||
|
# 时间戳
|
||||||
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
last_accessed: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
expires_at: Optional[datetime] = Field(default=None, description="过期时间")
|
||||||
|
|
||||||
|
class Settings:
|
||||||
|
name = "agent_memories"
|
||||||
|
indexes = [
|
||||||
|
[("agent_id", 1)],
|
||||||
|
[("memory_type", 1)],
|
||||||
|
[("importance", -1)],
|
||||||
|
[("last_accessed", -1)],
|
||||||
|
]
|
||||||
|
|
||||||
|
def access(self) -> None:
|
||||||
|
"""
|
||||||
|
记录访问,更新访问计数和时间
|
||||||
|
"""
|
||||||
|
self.access_count += 1
|
||||||
|
self.last_accessed = datetime.utcnow()
|
||||||
|
|
||||||
|
def is_expired(self) -> bool:
|
||||||
|
"""
|
||||||
|
检查记忆是否已过期
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否过期
|
||||||
|
"""
|
||||||
|
if self.expires_at is None:
|
||||||
|
return False
|
||||||
|
return datetime.utcnow() > self.expires_at
|
||||||
|
|
||||||
|
def calculate_relevance_score(
|
||||||
|
self,
|
||||||
|
similarity: float,
|
||||||
|
time_decay_factor: float = 0.1
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
计算综合相关性分数
|
||||||
|
结合向量相似度、重要性和时间衰减
|
||||||
|
|
||||||
|
Args:
|
||||||
|
similarity: 向量相似度 (0-1)
|
||||||
|
time_decay_factor: 时间衰减因子
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
综合相关性分数
|
||||||
|
"""
|
||||||
|
# 计算时间衰减
|
||||||
|
hours_since_access = (datetime.utcnow() - self.last_accessed).total_seconds() / 3600
|
||||||
|
time_decay = 1.0 / (1.0 + time_decay_factor * hours_since_access)
|
||||||
|
|
||||||
|
# 综合评分
|
||||||
|
score = (
|
||||||
|
0.5 * similarity +
|
||||||
|
0.3 * self.importance +
|
||||||
|
0.2 * time_decay
|
||||||
|
)
|
||||||
|
|
||||||
|
return min(1.0, max(0.0, score))
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"memory_id": "mem-001",
|
||||||
|
"agent_id": "product-manager",
|
||||||
|
"memory_type": "long_term",
|
||||||
|
"content": "在登录系统设计讨论中,团队决定采用OAuth2.0方案",
|
||||||
|
"summary": "登录系统采用OAuth2.0",
|
||||||
|
"importance": 0.8,
|
||||||
|
"access_count": 5,
|
||||||
|
"source_room_id": "product-design-room",
|
||||||
|
"tags": ["登录", "OAuth", "认证"]
|
||||||
|
}
|
||||||
|
}
|
||||||
149
backend/models/ai_provider.py
Normal file
149
backend/models/ai_provider.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
"""
|
||||||
|
AI接口提供商数据模型
|
||||||
|
定义AI服务配置结构
|
||||||
|
"""
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, Dict, Any, List
|
||||||
|
from enum import Enum
|
||||||
|
from pydantic import Field
|
||||||
|
from beanie import Document
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderType(str, Enum):
|
||||||
|
"""AI提供商类型枚举"""
|
||||||
|
MINIMAX = "minimax"
|
||||||
|
ZHIPU = "zhipu"
|
||||||
|
OPENROUTER = "openrouter"
|
||||||
|
KIMI = "kimi"
|
||||||
|
DEEPSEEK = "deepseek"
|
||||||
|
GEMINI = "gemini"
|
||||||
|
OLLAMA = "ollama"
|
||||||
|
LLMSTUDIO = "llmstudio"
|
||||||
|
|
||||||
|
|
||||||
|
class ProxyConfig:
|
||||||
|
"""代理配置"""
|
||||||
|
http_proxy: Optional[str] = None # HTTP代理地址
|
||||||
|
https_proxy: Optional[str] = None # HTTPS代理地址
|
||||||
|
no_proxy: List[str] = [] # 不使用代理的域名列表
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
http_proxy: Optional[str] = None,
|
||||||
|
https_proxy: Optional[str] = None,
|
||||||
|
no_proxy: Optional[List[str]] = None
|
||||||
|
):
|
||||||
|
self.http_proxy = http_proxy
|
||||||
|
self.https_proxy = https_proxy
|
||||||
|
self.no_proxy = no_proxy or []
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""转换为字典"""
|
||||||
|
return {
|
||||||
|
"http_proxy": self.http_proxy,
|
||||||
|
"https_proxy": self.https_proxy,
|
||||||
|
"no_proxy": self.no_proxy
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "ProxyConfig":
|
||||||
|
"""从字典创建"""
|
||||||
|
if not data:
|
||||||
|
return cls()
|
||||||
|
return cls(
|
||||||
|
http_proxy=data.get("http_proxy"),
|
||||||
|
https_proxy=data.get("https_proxy"),
|
||||||
|
no_proxy=data.get("no_proxy", [])
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimit:
|
||||||
|
"""速率限制配置"""
|
||||||
|
requests_per_minute: int = 60 # 每分钟请求数
|
||||||
|
tokens_per_minute: int = 100000 # 每分钟token数
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
requests_per_minute: int = 60,
|
||||||
|
tokens_per_minute: int = 100000
|
||||||
|
):
|
||||||
|
self.requests_per_minute = requests_per_minute
|
||||||
|
self.tokens_per_minute = tokens_per_minute
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, int]:
|
||||||
|
"""转换为字典"""
|
||||||
|
return {
|
||||||
|
"requests_per_minute": self.requests_per_minute,
|
||||||
|
"tokens_per_minute": self.tokens_per_minute
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, int]) -> "RateLimit":
|
||||||
|
"""从字典创建"""
|
||||||
|
if not data:
|
||||||
|
return cls()
|
||||||
|
return cls(
|
||||||
|
requests_per_minute=data.get("requests_per_minute", 60),
|
||||||
|
tokens_per_minute=data.get("tokens_per_minute", 100000)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AIProvider(Document):
|
||||||
|
"""
|
||||||
|
AI接口提供商文档模型
|
||||||
|
存储各AI服务的配置信息
|
||||||
|
"""
|
||||||
|
provider_id: str = Field(..., description="唯一标识")
|
||||||
|
provider_type: str = Field(..., description="提供商类型: minimax, zhipu等")
|
||||||
|
name: str = Field(..., description="自定义名称")
|
||||||
|
api_key: str = Field(default="", description="API密钥(加密存储)")
|
||||||
|
base_url: str = Field(default="", description="API基础URL")
|
||||||
|
model: str = Field(..., description="使用的模型名称")
|
||||||
|
|
||||||
|
# 代理配置
|
||||||
|
use_proxy: bool = Field(default=False, description="是否使用代理")
|
||||||
|
proxy_config: Dict[str, Any] = Field(default_factory=dict, description="代理配置")
|
||||||
|
|
||||||
|
# 速率限制
|
||||||
|
rate_limit: Dict[str, int] = Field(
|
||||||
|
default_factory=lambda: {"requests_per_minute": 60, "tokens_per_minute": 100000},
|
||||||
|
description="速率限制配置"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 其他配置
|
||||||
|
timeout: int = Field(default=60, description="超时时间(秒)")
|
||||||
|
extra_params: Dict[str, Any] = Field(default_factory=dict, description="额外参数")
|
||||||
|
|
||||||
|
# 元数据
|
||||||
|
enabled: bool = Field(default=True, description="是否启用")
|
||||||
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
class Settings:
|
||||||
|
name = "ai_providers"
|
||||||
|
|
||||||
|
def get_proxy_config(self) -> ProxyConfig:
|
||||||
|
"""获取代理配置对象"""
|
||||||
|
return ProxyConfig.from_dict(self.proxy_config)
|
||||||
|
|
||||||
|
def get_rate_limit(self) -> RateLimit:
|
||||||
|
"""获取速率限制配置对象"""
|
||||||
|
return RateLimit.from_dict(self.rate_limit)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"provider_id": "openrouter-gpt4",
|
||||||
|
"provider_type": "openrouter",
|
||||||
|
"name": "OpenRouter GPT-4",
|
||||||
|
"api_key": "sk-xxx",
|
||||||
|
"base_url": "https://openrouter.ai/api/v1",
|
||||||
|
"model": "openai/gpt-4-turbo",
|
||||||
|
"use_proxy": True,
|
||||||
|
"proxy_config": {
|
||||||
|
"http_proxy": "http://127.0.0.1:7890",
|
||||||
|
"https_proxy": "http://127.0.0.1:7890"
|
||||||
|
},
|
||||||
|
"timeout": 60
|
||||||
|
}
|
||||||
|
}
|
||||||
131
backend/models/chatroom.py
Normal file
131
backend/models/chatroom.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""
|
||||||
|
聊天室数据模型
|
||||||
|
定义讨论聊天室的配置结构
|
||||||
|
"""
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, Dict, Any, List
|
||||||
|
from enum import Enum
|
||||||
|
from pydantic import Field
|
||||||
|
from beanie import Document
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRoomStatus(str, Enum):
|
||||||
|
"""聊天室状态枚举"""
|
||||||
|
IDLE = "idle" # 空闲,等待开始
|
||||||
|
ACTIVE = "active" # 讨论进行中
|
||||||
|
PAUSED = "paused" # 暂停
|
||||||
|
COMPLETED = "completed" # 已完成
|
||||||
|
ERROR = "error" # 出错
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRoomConfig:
|
||||||
|
"""聊天室配置"""
|
||||||
|
max_rounds: int = 50 # 最大轮数(备用终止条件)
|
||||||
|
message_history_size: int = 20 # 上下文消息数
|
||||||
|
consensus_threshold: float = 0.8 # 共识阈值
|
||||||
|
round_interval: float = 1.0 # 轮次间隔(秒)
|
||||||
|
allow_user_interrupt: bool = True # 允许用户中断
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_rounds: int = 50,
|
||||||
|
message_history_size: int = 20,
|
||||||
|
consensus_threshold: float = 0.8,
|
||||||
|
round_interval: float = 1.0,
|
||||||
|
allow_user_interrupt: bool = True
|
||||||
|
):
|
||||||
|
self.max_rounds = max_rounds
|
||||||
|
self.message_history_size = message_history_size
|
||||||
|
self.consensus_threshold = consensus_threshold
|
||||||
|
self.round_interval = round_interval
|
||||||
|
self.allow_user_interrupt = allow_user_interrupt
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""转换为字典"""
|
||||||
|
return {
|
||||||
|
"max_rounds": self.max_rounds,
|
||||||
|
"message_history_size": self.message_history_size,
|
||||||
|
"consensus_threshold": self.consensus_threshold,
|
||||||
|
"round_interval": self.round_interval,
|
||||||
|
"allow_user_interrupt": self.allow_user_interrupt
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "ChatRoomConfig":
|
||||||
|
"""从字典创建"""
|
||||||
|
if not data:
|
||||||
|
return cls()
|
||||||
|
return cls(
|
||||||
|
max_rounds=data.get("max_rounds", 50),
|
||||||
|
message_history_size=data.get("message_history_size", 20),
|
||||||
|
consensus_threshold=data.get("consensus_threshold", 0.8),
|
||||||
|
round_interval=data.get("round_interval", 1.0),
|
||||||
|
allow_user_interrupt=data.get("allow_user_interrupt", True)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRoom(Document):
|
||||||
|
"""
|
||||||
|
聊天室文档模型
|
||||||
|
存储讨论聊天室的配置信息
|
||||||
|
"""
|
||||||
|
room_id: str = Field(..., description="唯一标识")
|
||||||
|
name: str = Field(..., description="聊天室名称")
|
||||||
|
description: str = Field(default="", description="描述")
|
||||||
|
objective: str = Field(default="", description="当前讨论目标")
|
||||||
|
|
||||||
|
# 参与者
|
||||||
|
agents: List[str] = Field(default_factory=list, description="Agent ID列表")
|
||||||
|
moderator_agent_id: Optional[str] = Field(default=None, description="共识判断Agent ID")
|
||||||
|
|
||||||
|
# 配置
|
||||||
|
config: Dict[str, Any] = Field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"max_rounds": 50,
|
||||||
|
"message_history_size": 20,
|
||||||
|
"consensus_threshold": 0.8,
|
||||||
|
"round_interval": 1.0,
|
||||||
|
"allow_user_interrupt": True
|
||||||
|
},
|
||||||
|
description="聊天室配置"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 状态
|
||||||
|
status: str = Field(default=ChatRoomStatus.IDLE.value, description="当前状态")
|
||||||
|
current_round: int = Field(default=0, description="当前轮次")
|
||||||
|
current_discussion_id: Optional[str] = Field(default=None, description="当前讨论ID")
|
||||||
|
|
||||||
|
# 元数据
|
||||||
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
completed_at: Optional[datetime] = Field(default=None, description="完成时间")
|
||||||
|
|
||||||
|
class Settings:
|
||||||
|
name = "chatrooms"
|
||||||
|
|
||||||
|
def get_config(self) -> ChatRoomConfig:
|
||||||
|
"""获取配置对象"""
|
||||||
|
return ChatRoomConfig.from_dict(self.config)
|
||||||
|
|
||||||
|
def is_active(self) -> bool:
|
||||||
|
"""检查聊天室是否处于活动状态"""
|
||||||
|
return self.status == ChatRoomStatus.ACTIVE.value
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"room_id": "product-design-room",
|
||||||
|
"name": "产品设计讨论室",
|
||||||
|
"description": "用于讨论新产品功能设计",
|
||||||
|
"objective": "设计一个用户友好的登录系统",
|
||||||
|
"agents": ["product-manager", "designer", "developer"],
|
||||||
|
"moderator_agent_id": "moderator",
|
||||||
|
"config": {
|
||||||
|
"max_rounds": 50,
|
||||||
|
"message_history_size": 20,
|
||||||
|
"consensus_threshold": 0.8
|
||||||
|
},
|
||||||
|
"status": "idle",
|
||||||
|
"current_round": 0
|
||||||
|
}
|
||||||
|
}
|
||||||
126
backend/models/discussion_result.py
Normal file
126
backend/models/discussion_result.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
"""
|
||||||
|
讨论结果数据模型
|
||||||
|
定义讨论结果的结构
|
||||||
|
"""
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, Dict, Any, List
|
||||||
|
from pydantic import Field
|
||||||
|
from beanie import Document
|
||||||
|
|
||||||
|
|
||||||
|
class DiscussionResult(Document):
|
||||||
|
"""
|
||||||
|
讨论结果文档模型
|
||||||
|
存储讨论的最终结果
|
||||||
|
"""
|
||||||
|
discussion_id: str = Field(..., description="讨论唯一标识")
|
||||||
|
room_id: str = Field(..., description="聊天室ID")
|
||||||
|
objective: str = Field(..., description="讨论目标")
|
||||||
|
|
||||||
|
# 共识结果
|
||||||
|
consensus_reached: bool = Field(default=False, description="是否达成共识")
|
||||||
|
confidence: float = Field(default=0.0, ge=0, le=1, description="共识置信度")
|
||||||
|
|
||||||
|
# 结果摘要
|
||||||
|
summary: str = Field(default="", description="讨论结果摘要")
|
||||||
|
action_items: List[str] = Field(default_factory=list, description="行动项列表")
|
||||||
|
unresolved_issues: List[str] = Field(default_factory=list, description="未解决的问题")
|
||||||
|
key_decisions: List[str] = Field(default_factory=list, description="关键决策")
|
||||||
|
|
||||||
|
# 统计信息
|
||||||
|
total_rounds: int = Field(default=0, description="总轮数")
|
||||||
|
total_messages: int = Field(default=0, description="总消息数")
|
||||||
|
participating_agents: List[str] = Field(default_factory=list, description="参与的Agent列表")
|
||||||
|
agent_contributions: Dict[str, int] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="各Agent发言次数统计"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 状态
|
||||||
|
status: str = Field(default="in_progress", description="状态: in_progress, completed, failed")
|
||||||
|
end_reason: str = Field(default="", description="结束原因")
|
||||||
|
|
||||||
|
# 时间戳
|
||||||
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
completed_at: Optional[datetime] = Field(default=None, description="完成时间")
|
||||||
|
|
||||||
|
class Settings:
|
||||||
|
name = "discussions"
|
||||||
|
indexes = [
|
||||||
|
[("room_id", 1)],
|
||||||
|
[("created_at", -1)],
|
||||||
|
]
|
||||||
|
|
||||||
|
def mark_completed(
|
||||||
|
self,
|
||||||
|
consensus_reached: bool,
|
||||||
|
confidence: float,
|
||||||
|
summary: str,
|
||||||
|
action_items: List[str] = None,
|
||||||
|
unresolved_issues: List[str] = None,
|
||||||
|
end_reason: str = "consensus"
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
标记讨论为已完成
|
||||||
|
|
||||||
|
Args:
|
||||||
|
consensus_reached: 是否达成共识
|
||||||
|
confidence: 置信度
|
||||||
|
summary: 结果摘要
|
||||||
|
action_items: 行动项
|
||||||
|
unresolved_issues: 未解决问题
|
||||||
|
end_reason: 结束原因
|
||||||
|
"""
|
||||||
|
self.consensus_reached = consensus_reached
|
||||||
|
self.confidence = confidence
|
||||||
|
self.summary = summary
|
||||||
|
self.action_items = action_items or []
|
||||||
|
self.unresolved_issues = unresolved_issues or []
|
||||||
|
self.status = "completed"
|
||||||
|
self.end_reason = end_reason
|
||||||
|
self.completed_at = datetime.utcnow()
|
||||||
|
self.updated_at = datetime.utcnow()
|
||||||
|
|
||||||
|
def update_stats(
|
||||||
|
self,
|
||||||
|
total_rounds: int,
|
||||||
|
total_messages: int,
|
||||||
|
agent_contributions: Dict[str, int]
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
更新统计信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
total_rounds: 总轮数
|
||||||
|
total_messages: 总消息数
|
||||||
|
agent_contributions: Agent贡献统计
|
||||||
|
"""
|
||||||
|
self.total_rounds = total_rounds
|
||||||
|
self.total_messages = total_messages
|
||||||
|
self.agent_contributions = agent_contributions
|
||||||
|
self.participating_agents = list(agent_contributions.keys())
|
||||||
|
self.updated_at = datetime.utcnow()
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"discussion_id": "disc-001",
|
||||||
|
"room_id": "product-design-room",
|
||||||
|
"objective": "设计用户登录系统",
|
||||||
|
"consensus_reached": True,
|
||||||
|
"confidence": 0.85,
|
||||||
|
"summary": "团队一致同意采用OAuth2.0 + 手机验证码的混合认证方案...",
|
||||||
|
"action_items": [
|
||||||
|
"设计OAuth2.0集成方案",
|
||||||
|
"开发短信验证服务",
|
||||||
|
"编写安全测试用例"
|
||||||
|
],
|
||||||
|
"unresolved_issues": [
|
||||||
|
"第三方登录的优先级排序"
|
||||||
|
],
|
||||||
|
"total_rounds": 15,
|
||||||
|
"total_messages": 45,
|
||||||
|
"status": "completed"
|
||||||
|
}
|
||||||
|
}
|
||||||
123
backend/models/message.py
Normal file
123
backend/models/message.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
"""
|
||||||
|
消息数据模型
|
||||||
|
定义聊天消息的结构
|
||||||
|
"""
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, Dict, Any, List
|
||||||
|
from enum import Enum
|
||||||
|
from pydantic import Field
|
||||||
|
from beanie import Document
|
||||||
|
|
||||||
|
|
||||||
|
class MessageType(str, Enum):
|
||||||
|
"""消息类型枚举"""
|
||||||
|
TEXT = "text" # 纯文本
|
||||||
|
IMAGE = "image" # 图片
|
||||||
|
FILE = "file" # 文件
|
||||||
|
SYSTEM = "system" # 系统消息
|
||||||
|
ACTION = "action" # 动作消息(如调用工具)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageAttachment:
|
||||||
|
"""消息附件"""
|
||||||
|
attachment_type: str # 附件类型: image, file
|
||||||
|
url: str # 资源URL
|
||||||
|
name: str # 文件名
|
||||||
|
size: int = 0 # 文件大小(字节)
|
||||||
|
mime_type: str = "" # MIME类型
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
attachment_type: str,
|
||||||
|
url: str,
|
||||||
|
name: str,
|
||||||
|
size: int = 0,
|
||||||
|
mime_type: str = ""
|
||||||
|
):
|
||||||
|
self.attachment_type = attachment_type
|
||||||
|
self.url = url
|
||||||
|
self.name = name
|
||||||
|
self.size = size
|
||||||
|
self.mime_type = mime_type
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""转换为字典"""
|
||||||
|
return {
|
||||||
|
"attachment_type": self.attachment_type,
|
||||||
|
"url": self.url,
|
||||||
|
"name": self.name,
|
||||||
|
"size": self.size,
|
||||||
|
"mime_type": self.mime_type
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "MessageAttachment":
|
||||||
|
"""从字典创建"""
|
||||||
|
return cls(
|
||||||
|
attachment_type=data.get("attachment_type", ""),
|
||||||
|
url=data.get("url", ""),
|
||||||
|
name=data.get("name", ""),
|
||||||
|
size=data.get("size", 0),
|
||||||
|
mime_type=data.get("mime_type", "")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Message(Document):
|
||||||
|
"""
|
||||||
|
消息文档模型
|
||||||
|
存储聊天消息
|
||||||
|
"""
|
||||||
|
message_id: str = Field(..., description="唯一标识")
|
||||||
|
room_id: str = Field(..., description="聊天室ID")
|
||||||
|
discussion_id: str = Field(..., description="讨论ID")
|
||||||
|
agent_id: Optional[str] = Field(default=None, description="发送Agent ID(系统消息为空)")
|
||||||
|
|
||||||
|
# 消息内容
|
||||||
|
content: str = Field(..., description="消息内容")
|
||||||
|
message_type: str = Field(default=MessageType.TEXT.value, description="消息类型")
|
||||||
|
attachments: List[Dict[str, Any]] = Field(default_factory=list, description="附件列表")
|
||||||
|
|
||||||
|
# 元数据
|
||||||
|
round: int = Field(default=0, description="所属轮次")
|
||||||
|
token_count: int = Field(default=0, description="token数量")
|
||||||
|
|
||||||
|
# 工具调用相关
|
||||||
|
tool_calls: List[Dict[str, Any]] = Field(default_factory=list, description="工具调用记录")
|
||||||
|
tool_results: List[Dict[str, Any]] = Field(default_factory=list, description="工具调用结果")
|
||||||
|
|
||||||
|
# 时间戳
|
||||||
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
class Settings:
|
||||||
|
name = "messages"
|
||||||
|
indexes = [
|
||||||
|
[("room_id", 1), ("created_at", 1)],
|
||||||
|
[("discussion_id", 1)],
|
||||||
|
[("agent_id", 1)],
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_attachments(self) -> List[MessageAttachment]:
|
||||||
|
"""获取附件对象列表"""
|
||||||
|
return [MessageAttachment.from_dict(a) for a in self.attachments]
|
||||||
|
|
||||||
|
def is_from_agent(self, agent_id: str) -> bool:
|
||||||
|
"""检查消息是否来自指定Agent"""
|
||||||
|
return self.agent_id == agent_id
|
||||||
|
|
||||||
|
def is_system_message(self) -> bool:
|
||||||
|
"""检查是否为系统消息"""
|
||||||
|
return self.message_type == MessageType.SYSTEM.value
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"message_id": "msg-001",
|
||||||
|
"room_id": "product-design-room",
|
||||||
|
"discussion_id": "disc-001",
|
||||||
|
"agent_id": "product-manager",
|
||||||
|
"content": "我认为登录系统应该支持多种认证方式...",
|
||||||
|
"message_type": "text",
|
||||||
|
"round": 1,
|
||||||
|
"token_count": 150
|
||||||
|
}
|
||||||
|
}
|
||||||
42
backend/requirements.txt
Normal file
42
backend/requirements.txt
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
# FastAPI and server
|
||||||
|
fastapi==0.109.0
|
||||||
|
uvicorn[standard]==0.27.0
|
||||||
|
python-multipart==0.0.6
|
||||||
|
websockets==12.0
|
||||||
|
|
||||||
|
# MongoDB
|
||||||
|
motor==3.3.2
|
||||||
|
pymongo==4.6.1
|
||||||
|
beanie==1.25.0
|
||||||
|
|
||||||
|
# HTTP client
|
||||||
|
httpx==0.26.0
|
||||||
|
aiohttp==3.9.1
|
||||||
|
|
||||||
|
# AI SDK clients
|
||||||
|
openai==1.12.0
|
||||||
|
google-generativeai==0.3.2
|
||||||
|
zhipuai==2.0.1
|
||||||
|
|
||||||
|
# Data validation
|
||||||
|
pydantic==2.6.0
|
||||||
|
pydantic-settings==2.1.0
|
||||||
|
|
||||||
|
# Security
|
||||||
|
cryptography==42.0.2
|
||||||
|
python-jose[cryptography]==3.3.0
|
||||||
|
passlib[bcrypt]==1.7.4
|
||||||
|
|
||||||
|
# Utilities
|
||||||
|
python-dotenv==1.0.1
|
||||||
|
tenacity==8.2.3
|
||||||
|
numpy==1.26.4
|
||||||
|
|
||||||
|
# For embeddings and vector search
|
||||||
|
sentence-transformers==2.3.1
|
||||||
|
|
||||||
|
# Rate limiting
|
||||||
|
slowapi==0.1.9
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
loguru==0.7.2
|
||||||
14
backend/routers/__init__.py
Normal file
14
backend/routers/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
API路由模块
|
||||||
|
"""
|
||||||
|
from . import providers
|
||||||
|
from . import agents
|
||||||
|
from . import chatrooms
|
||||||
|
from . import discussions
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"providers",
|
||||||
|
"agents",
|
||||||
|
"chatrooms",
|
||||||
|
"discussions",
|
||||||
|
]
|
||||||
314
backend/routers/agents.py
Normal file
314
backend/routers/agents.py
Normal file
@@ -0,0 +1,314 @@
|
|||||||
|
"""
|
||||||
|
Agent管理路由
|
||||||
|
"""
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
from fastapi import APIRouter, HTTPException, status
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from services.agent_service import AgentService, AGENT_TEMPLATES
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 请求/响应模型 ============
|
||||||
|
|
||||||
|
class CapabilitiesModel(BaseModel):
|
||||||
|
"""能力配置模型"""
|
||||||
|
memory_enabled: bool = False
|
||||||
|
mcp_tools: List[str] = []
|
||||||
|
skills: List[str] = []
|
||||||
|
multimodal: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class BehaviorModel(BaseModel):
|
||||||
|
"""行为配置模型"""
|
||||||
|
speak_threshold: float = 0.5
|
||||||
|
max_speak_per_round: int = 2
|
||||||
|
speak_style: str = "balanced"
|
||||||
|
|
||||||
|
|
||||||
|
class AgentCreateRequest(BaseModel):
|
||||||
|
"""创建Agent请求"""
|
||||||
|
name: str = Field(..., description="Agent名称")
|
||||||
|
role: str = Field(..., description="角色定义")
|
||||||
|
system_prompt: str = Field(..., description="系统提示词")
|
||||||
|
provider_id: str = Field(..., description="使用的AI接口ID")
|
||||||
|
temperature: float = Field(default=0.7, ge=0, le=2, description="温度参数")
|
||||||
|
max_tokens: int = Field(default=2000, gt=0, description="最大token数")
|
||||||
|
capabilities: Optional[CapabilitiesModel] = None
|
||||||
|
behavior: Optional[BehaviorModel] = None
|
||||||
|
avatar: Optional[str] = None
|
||||||
|
color: str = "#1890ff"
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"name": "产品经理",
|
||||||
|
"role": "产品规划和需求分析专家",
|
||||||
|
"system_prompt": "你是一位经验丰富的产品经理...",
|
||||||
|
"provider_id": "openrouter-abc123",
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 2000
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AgentUpdateRequest(BaseModel):
|
||||||
|
"""更新Agent请求"""
|
||||||
|
name: Optional[str] = None
|
||||||
|
role: Optional[str] = None
|
||||||
|
system_prompt: Optional[str] = None
|
||||||
|
provider_id: Optional[str] = None
|
||||||
|
temperature: Optional[float] = Field(default=None, ge=0, le=2)
|
||||||
|
max_tokens: Optional[int] = Field(default=None, gt=0)
|
||||||
|
capabilities: Optional[CapabilitiesModel] = None
|
||||||
|
behavior: Optional[BehaviorModel] = None
|
||||||
|
avatar: Optional[str] = None
|
||||||
|
color: Optional[str] = None
|
||||||
|
enabled: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AgentResponse(BaseModel):
|
||||||
|
"""Agent响应"""
|
||||||
|
agent_id: str
|
||||||
|
name: str
|
||||||
|
role: str
|
||||||
|
system_prompt: str
|
||||||
|
provider_id: str
|
||||||
|
temperature: float
|
||||||
|
max_tokens: int
|
||||||
|
capabilities: Dict[str, Any]
|
||||||
|
behavior: Dict[str, Any]
|
||||||
|
avatar: Optional[str]
|
||||||
|
color: str
|
||||||
|
enabled: bool
|
||||||
|
created_at: str
|
||||||
|
updated_at: str
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTestRequest(BaseModel):
|
||||||
|
"""Agent测试请求"""
|
||||||
|
message: str = "你好,请简单介绍一下你自己。"
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTestResponse(BaseModel):
|
||||||
|
"""Agent测试响应"""
|
||||||
|
success: bool
|
||||||
|
message: str
|
||||||
|
response: Optional[str] = None
|
||||||
|
model: Optional[str] = None
|
||||||
|
tokens: Optional[int] = None
|
||||||
|
latency_ms: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TemplateResponse(BaseModel):
|
||||||
|
"""模板响应"""
|
||||||
|
template_id: str
|
||||||
|
name: str
|
||||||
|
role: str
|
||||||
|
system_prompt: str
|
||||||
|
color: str
|
||||||
|
|
||||||
|
|
||||||
|
class GeneratePromptRequest(BaseModel):
|
||||||
|
"""生成提示词请求"""
|
||||||
|
provider_id: str = Field(..., description="使用的AI接口ID")
|
||||||
|
name: str = Field(..., description="Agent名称")
|
||||||
|
role: str = Field(..., description="角色定位")
|
||||||
|
description: Optional[str] = Field(None, description="额外描述(可选)")
|
||||||
|
|
||||||
|
|
||||||
|
class GeneratePromptResponse(BaseModel):
|
||||||
|
"""生成提示词响应"""
|
||||||
|
success: bool
|
||||||
|
message: Optional[str] = None
|
||||||
|
prompt: Optional[str] = None
|
||||||
|
model: Optional[str] = None
|
||||||
|
tokens: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 路由处理 ============
|
||||||
|
|
||||||
|
@router.post("", response_model=AgentResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_agent(request: AgentCreateRequest):
|
||||||
|
"""
|
||||||
|
创建新的Agent
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
agent = await AgentService.create_agent(
|
||||||
|
name=request.name,
|
||||||
|
role=request.role,
|
||||||
|
system_prompt=request.system_prompt,
|
||||||
|
provider_id=request.provider_id,
|
||||||
|
temperature=request.temperature,
|
||||||
|
max_tokens=request.max_tokens,
|
||||||
|
capabilities=request.capabilities.dict() if request.capabilities else None,
|
||||||
|
behavior=request.behavior.dict() if request.behavior else None,
|
||||||
|
avatar=request.avatar,
|
||||||
|
color=request.color
|
||||||
|
)
|
||||||
|
|
||||||
|
return _to_response(agent)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建Agent失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail="创建失败")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=List[AgentResponse])
|
||||||
|
async def list_agents(enabled_only: bool = False):
|
||||||
|
"""
|
||||||
|
获取所有Agent
|
||||||
|
"""
|
||||||
|
agents = await AgentService.get_all_agents(enabled_only)
|
||||||
|
return [_to_response(a) for a in agents]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/templates", response_model=List[TemplateResponse])
|
||||||
|
async def list_templates():
|
||||||
|
"""
|
||||||
|
获取Agent预设模板
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
TemplateResponse(
|
||||||
|
template_id=tid,
|
||||||
|
name=t["name"],
|
||||||
|
role=t["role"],
|
||||||
|
system_prompt=t["system_prompt"],
|
||||||
|
color=t["color"]
|
||||||
|
)
|
||||||
|
for tid, t in AGENT_TEMPLATES.items()
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/generate-prompt", response_model=GeneratePromptResponse)
|
||||||
|
async def generate_prompt(request: GeneratePromptRequest):
|
||||||
|
"""
|
||||||
|
使用AI生成Agent系统提示词
|
||||||
|
"""
|
||||||
|
result = await AgentService.generate_system_prompt(
|
||||||
|
provider_id=request.provider_id,
|
||||||
|
name=request.name,
|
||||||
|
role=request.role,
|
||||||
|
description=request.description
|
||||||
|
)
|
||||||
|
return GeneratePromptResponse(**result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{agent_id}", response_model=AgentResponse)
|
||||||
|
async def get_agent(agent_id: str):
|
||||||
|
"""
|
||||||
|
获取指定Agent
|
||||||
|
"""
|
||||||
|
agent = await AgentService.get_agent(agent_id)
|
||||||
|
if not agent:
|
||||||
|
raise HTTPException(status_code=404, detail="Agent不存在")
|
||||||
|
return _to_response(agent)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{agent_id}", response_model=AgentResponse)
|
||||||
|
async def update_agent(agent_id: str, request: AgentUpdateRequest):
|
||||||
|
"""
|
||||||
|
更新Agent配置
|
||||||
|
"""
|
||||||
|
update_data = request.dict(exclude_unset=True)
|
||||||
|
|
||||||
|
# 转换嵌套模型
|
||||||
|
if "capabilities" in update_data and update_data["capabilities"]:
|
||||||
|
if hasattr(update_data["capabilities"], "dict"):
|
||||||
|
update_data["capabilities"] = update_data["capabilities"].dict()
|
||||||
|
if "behavior" in update_data and update_data["behavior"]:
|
||||||
|
if hasattr(update_data["behavior"], "dict"):
|
||||||
|
update_data["behavior"] = update_data["behavior"].dict()
|
||||||
|
|
||||||
|
try:
|
||||||
|
agent = await AgentService.update_agent(agent_id, **update_data)
|
||||||
|
if not agent:
|
||||||
|
raise HTTPException(status_code=404, detail="Agent不存在")
|
||||||
|
return _to_response(agent)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{agent_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def delete_agent(agent_id: str):
|
||||||
|
"""
|
||||||
|
删除Agent
|
||||||
|
"""
|
||||||
|
success = await AgentService.delete_agent(agent_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=404, detail="Agent不存在")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{agent_id}/test", response_model=AgentTestResponse)
|
||||||
|
async def test_agent(agent_id: str, request: AgentTestRequest = None):
|
||||||
|
"""
|
||||||
|
测试Agent对话
|
||||||
|
"""
|
||||||
|
message = request.message if request else "你好,请简单介绍一下你自己。"
|
||||||
|
result = await AgentService.test_agent(agent_id, message)
|
||||||
|
return AgentTestResponse(**result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{agent_id}/duplicate", response_model=AgentResponse)
|
||||||
|
async def duplicate_agent(agent_id: str, new_name: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
复制Agent
|
||||||
|
"""
|
||||||
|
agent = await AgentService.duplicate_agent(agent_id, new_name)
|
||||||
|
if not agent:
|
||||||
|
raise HTTPException(status_code=404, detail="源Agent不存在")
|
||||||
|
return _to_response(agent)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/from-template/{template_id}", response_model=AgentResponse)
|
||||||
|
async def create_from_template(template_id: str, provider_id: str):
|
||||||
|
"""
|
||||||
|
从模板创建Agent
|
||||||
|
"""
|
||||||
|
if template_id not in AGENT_TEMPLATES:
|
||||||
|
raise HTTPException(status_code=404, detail="模板不存在")
|
||||||
|
|
||||||
|
template = AGENT_TEMPLATES[template_id]
|
||||||
|
|
||||||
|
try:
|
||||||
|
agent = await AgentService.create_agent(
|
||||||
|
name=template["name"],
|
||||||
|
role=template["role"],
|
||||||
|
system_prompt=template["system_prompt"],
|
||||||
|
provider_id=provider_id,
|
||||||
|
color=template["color"]
|
||||||
|
)
|
||||||
|
return _to_response(agent)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 辅助函数 ============
|
||||||
|
|
||||||
|
def _to_response(agent) -> AgentResponse:
|
||||||
|
"""
|
||||||
|
转换为响应模型
|
||||||
|
"""
|
||||||
|
return AgentResponse(
|
||||||
|
agent_id=agent.agent_id,
|
||||||
|
name=agent.name,
|
||||||
|
role=agent.role,
|
||||||
|
system_prompt=agent.system_prompt,
|
||||||
|
provider_id=agent.provider_id,
|
||||||
|
temperature=agent.temperature,
|
||||||
|
max_tokens=agent.max_tokens,
|
||||||
|
capabilities=agent.capabilities,
|
||||||
|
behavior=agent.behavior,
|
||||||
|
avatar=agent.avatar,
|
||||||
|
color=agent.color,
|
||||||
|
enabled=agent.enabled,
|
||||||
|
created_at=agent.created_at.isoformat(),
|
||||||
|
updated_at=agent.updated_at.isoformat()
|
||||||
|
)
|
||||||
387
backend/routers/chatrooms.py
Normal file
387
backend/routers/chatrooms.py
Normal file
@@ -0,0 +1,387 @@
|
|||||||
|
"""
|
||||||
|
聊天室管理路由
|
||||||
|
"""
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect, status
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from services.chatroom_service import ChatRoomService
|
||||||
|
from services.discussion_engine import DiscussionEngine
|
||||||
|
from services.message_router import MessageRouter
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 请求/响应模型 ============
|
||||||
|
|
||||||
|
class ChatRoomConfigModel(BaseModel):
|
||||||
|
"""聊天室配置模型"""
|
||||||
|
max_rounds: int = 50
|
||||||
|
message_history_size: int = 20
|
||||||
|
consensus_threshold: float = 0.8
|
||||||
|
round_interval: float = 1.0
|
||||||
|
allow_user_interrupt: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRoomCreateRequest(BaseModel):
|
||||||
|
"""创建聊天室请求"""
|
||||||
|
name: str = Field(..., description="聊天室名称")
|
||||||
|
description: str = Field(default="", description="描述")
|
||||||
|
agents: List[str] = Field(default=[], description="Agent ID列表")
|
||||||
|
moderator_agent_id: Optional[str] = Field(default=None, description="主持人Agent ID")
|
||||||
|
config: Optional[ChatRoomConfigModel] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"name": "产品设计讨论室",
|
||||||
|
"description": "用于讨论新产品功能设计",
|
||||||
|
"agents": ["agent-abc123", "agent-def456"],
|
||||||
|
"moderator_agent_id": "agent-xyz789"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRoomUpdateRequest(BaseModel):
|
||||||
|
"""更新聊天室请求"""
|
||||||
|
name: Optional[str] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
agents: Optional[List[str]] = None
|
||||||
|
moderator_agent_id: Optional[str] = None
|
||||||
|
config: Optional[ChatRoomConfigModel] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRoomResponse(BaseModel):
|
||||||
|
"""聊天室响应"""
|
||||||
|
room_id: str
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
objective: str
|
||||||
|
agents: List[str]
|
||||||
|
moderator_agent_id: Optional[str]
|
||||||
|
config: Dict[str, Any]
|
||||||
|
status: str
|
||||||
|
current_round: int
|
||||||
|
current_discussion_id: Optional[str]
|
||||||
|
created_at: str
|
||||||
|
updated_at: str
|
||||||
|
completed_at: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class MessageResponse(BaseModel):
|
||||||
|
"""消息响应"""
|
||||||
|
message_id: str
|
||||||
|
room_id: str
|
||||||
|
discussion_id: str
|
||||||
|
agent_id: Optional[str]
|
||||||
|
content: str
|
||||||
|
message_type: str
|
||||||
|
round: int
|
||||||
|
created_at: str
|
||||||
|
|
||||||
|
|
||||||
|
class StartDiscussionRequest(BaseModel):
|
||||||
|
"""启动讨论请求"""
|
||||||
|
objective: str = Field(..., description="讨论目标")
|
||||||
|
|
||||||
|
|
||||||
|
class DiscussionStatusResponse(BaseModel):
|
||||||
|
"""讨论状态响应"""
|
||||||
|
is_active: bool
|
||||||
|
room_id: str
|
||||||
|
discussion_id: Optional[str] = None
|
||||||
|
current_round: int = 0
|
||||||
|
status: str
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 路由处理 ============
|
||||||
|
|
||||||
|
@router.post("", response_model=ChatRoomResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_chatroom(request: ChatRoomCreateRequest):
|
||||||
|
"""
|
||||||
|
创建新的聊天室
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
chatroom = await ChatRoomService.create_chatroom(
|
||||||
|
name=request.name,
|
||||||
|
description=request.description,
|
||||||
|
agents=request.agents,
|
||||||
|
moderator_agent_id=request.moderator_agent_id,
|
||||||
|
config=request.config.dict() if request.config else None
|
||||||
|
)
|
||||||
|
|
||||||
|
return _to_response(chatroom)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建聊天室失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail="创建失败")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=List[ChatRoomResponse])
|
||||||
|
async def list_chatrooms():
|
||||||
|
"""
|
||||||
|
获取所有聊天室
|
||||||
|
"""
|
||||||
|
chatrooms = await ChatRoomService.get_all_chatrooms()
|
||||||
|
return [_to_response(c) for c in chatrooms]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{room_id}", response_model=ChatRoomResponse)
|
||||||
|
async def get_chatroom(room_id: str):
|
||||||
|
"""
|
||||||
|
获取指定聊天室
|
||||||
|
"""
|
||||||
|
chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||||
|
if not chatroom:
|
||||||
|
raise HTTPException(status_code=404, detail="聊天室不存在")
|
||||||
|
return _to_response(chatroom)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{room_id}", response_model=ChatRoomResponse)
|
||||||
|
async def update_chatroom(room_id: str, request: ChatRoomUpdateRequest):
|
||||||
|
"""
|
||||||
|
更新聊天室配置
|
||||||
|
"""
|
||||||
|
update_data = request.dict(exclude_unset=True)
|
||||||
|
|
||||||
|
if "config" in update_data and update_data["config"]:
|
||||||
|
if hasattr(update_data["config"], "dict"):
|
||||||
|
update_data["config"] = update_data["config"].dict()
|
||||||
|
|
||||||
|
try:
|
||||||
|
chatroom = await ChatRoomService.update_chatroom(room_id, **update_data)
|
||||||
|
if not chatroom:
|
||||||
|
raise HTTPException(status_code=404, detail="聊天室不存在")
|
||||||
|
return _to_response(chatroom)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{room_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def delete_chatroom(room_id: str):
|
||||||
|
"""
|
||||||
|
删除聊天室
|
||||||
|
"""
|
||||||
|
success = await ChatRoomService.delete_chatroom(room_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=404, detail="聊天室不存在")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{room_id}/agents/{agent_id}", response_model=ChatRoomResponse)
|
||||||
|
async def add_agent_to_chatroom(room_id: str, agent_id: str):
|
||||||
|
"""
|
||||||
|
向聊天室添加Agent
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
chatroom = await ChatRoomService.add_agent(room_id, agent_id)
|
||||||
|
if not chatroom:
|
||||||
|
raise HTTPException(status_code=404, detail="聊天室不存在")
|
||||||
|
return _to_response(chatroom)
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{room_id}/agents/{agent_id}", response_model=ChatRoomResponse)
|
||||||
|
async def remove_agent_from_chatroom(room_id: str, agent_id: str):
|
||||||
|
"""
|
||||||
|
从聊天室移除Agent
|
||||||
|
"""
|
||||||
|
chatroom = await ChatRoomService.remove_agent(room_id, agent_id)
|
||||||
|
if not chatroom:
|
||||||
|
raise HTTPException(status_code=404, detail="聊天室不存在")
|
||||||
|
return _to_response(chatroom)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{room_id}/messages", response_model=List[MessageResponse])
|
||||||
|
async def get_chatroom_messages(
|
||||||
|
room_id: str,
|
||||||
|
limit: int = 50,
|
||||||
|
skip: int = 0,
|
||||||
|
discussion_id: Optional[str] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取聊天室消息历史
|
||||||
|
"""
|
||||||
|
messages = await ChatRoomService.get_messages(
|
||||||
|
room_id, limit, skip, discussion_id
|
||||||
|
)
|
||||||
|
return [_message_to_response(m) for m in messages]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{room_id}/start", response_model=DiscussionStatusResponse)
|
||||||
|
async def start_discussion(room_id: str, request: StartDiscussionRequest):
|
||||||
|
"""
|
||||||
|
启动讨论
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 异步启动讨论(不等待完成)
|
||||||
|
import asyncio
|
||||||
|
asyncio.create_task(
|
||||||
|
DiscussionEngine.start_discussion(room_id, request.objective)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 等待一小段时间让讨论初始化
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||||
|
|
||||||
|
return DiscussionStatusResponse(
|
||||||
|
is_active=True,
|
||||||
|
room_id=room_id,
|
||||||
|
discussion_id=chatroom.current_discussion_id if chatroom else None,
|
||||||
|
current_round=chatroom.current_round if chatroom else 0,
|
||||||
|
status=chatroom.status if chatroom else "unknown"
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{room_id}/pause", response_model=DiscussionStatusResponse)
|
||||||
|
async def pause_discussion(room_id: str):
|
||||||
|
"""
|
||||||
|
暂停讨论
|
||||||
|
"""
|
||||||
|
success = await DiscussionEngine.pause_discussion(room_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=400, detail="没有进行中的讨论")
|
||||||
|
|
||||||
|
chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||||
|
return DiscussionStatusResponse(
|
||||||
|
is_active=False,
|
||||||
|
room_id=room_id,
|
||||||
|
discussion_id=chatroom.current_discussion_id if chatroom else None,
|
||||||
|
current_round=chatroom.current_round if chatroom else 0,
|
||||||
|
status="paused"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{room_id}/resume", response_model=DiscussionStatusResponse)
|
||||||
|
async def resume_discussion(room_id: str):
|
||||||
|
"""
|
||||||
|
恢复讨论
|
||||||
|
"""
|
||||||
|
success = await DiscussionEngine.resume_discussion(room_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=400, detail="聊天室不在暂停状态")
|
||||||
|
|
||||||
|
chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||||
|
return DiscussionStatusResponse(
|
||||||
|
is_active=True,
|
||||||
|
room_id=room_id,
|
||||||
|
discussion_id=chatroom.current_discussion_id if chatroom else None,
|
||||||
|
current_round=chatroom.current_round if chatroom else 0,
|
||||||
|
status="active"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{room_id}/stop", response_model=DiscussionStatusResponse)
|
||||||
|
async def stop_discussion(room_id: str):
|
||||||
|
"""
|
||||||
|
停止讨论
|
||||||
|
"""
|
||||||
|
success = await DiscussionEngine.stop_discussion(room_id)
|
||||||
|
|
||||||
|
chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||||
|
return DiscussionStatusResponse(
|
||||||
|
is_active=False,
|
||||||
|
room_id=room_id,
|
||||||
|
discussion_id=chatroom.current_discussion_id if chatroom else None,
|
||||||
|
current_round=chatroom.current_round if chatroom else 0,
|
||||||
|
status="stopping" if success else chatroom.status if chatroom else "unknown"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{room_id}/status", response_model=DiscussionStatusResponse)
|
||||||
|
async def get_discussion_status(room_id: str):
|
||||||
|
"""
|
||||||
|
获取讨论状态
|
||||||
|
"""
|
||||||
|
chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||||
|
if not chatroom:
|
||||||
|
raise HTTPException(status_code=404, detail="聊天室不存在")
|
||||||
|
|
||||||
|
is_active = DiscussionEngine.is_discussion_active(room_id)
|
||||||
|
|
||||||
|
return DiscussionStatusResponse(
|
||||||
|
is_active=is_active,
|
||||||
|
room_id=room_id,
|
||||||
|
discussion_id=chatroom.current_discussion_id,
|
||||||
|
current_round=chatroom.current_round,
|
||||||
|
status=chatroom.status
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============ WebSocket端点 ============
|
||||||
|
|
||||||
|
@router.websocket("/ws/{room_id}")
|
||||||
|
async def chatroom_websocket(websocket: WebSocket, room_id: str):
|
||||||
|
"""
|
||||||
|
聊天室WebSocket连接
|
||||||
|
"""
|
||||||
|
# 验证聊天室存在
|
||||||
|
chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||||
|
if not chatroom:
|
||||||
|
await websocket.close(code=4004, reason="聊天室不存在")
|
||||||
|
return
|
||||||
|
|
||||||
|
await MessageRouter.connect(room_id, websocket)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
# 保持连接,接收客户端消息(如心跳)
|
||||||
|
data = await websocket.receive_text()
|
||||||
|
|
||||||
|
# 处理心跳
|
||||||
|
if data == "ping":
|
||||||
|
await websocket.send_text("pong")
|
||||||
|
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
await MessageRouter.disconnect(room_id, websocket)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"WebSocket错误: {e}")
|
||||||
|
await MessageRouter.disconnect(room_id, websocket)
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 辅助函数 ============
|
||||||
|
|
||||||
|
def _to_response(chatroom) -> ChatRoomResponse:
|
||||||
|
"""
|
||||||
|
转换为响应模型
|
||||||
|
"""
|
||||||
|
return ChatRoomResponse(
|
||||||
|
room_id=chatroom.room_id,
|
||||||
|
name=chatroom.name,
|
||||||
|
description=chatroom.description,
|
||||||
|
objective=chatroom.objective,
|
||||||
|
agents=chatroom.agents,
|
||||||
|
moderator_agent_id=chatroom.moderator_agent_id,
|
||||||
|
config=chatroom.config,
|
||||||
|
status=chatroom.status,
|
||||||
|
current_round=chatroom.current_round,
|
||||||
|
current_discussion_id=chatroom.current_discussion_id,
|
||||||
|
created_at=chatroom.created_at.isoformat(),
|
||||||
|
updated_at=chatroom.updated_at.isoformat(),
|
||||||
|
completed_at=chatroom.completed_at.isoformat() if chatroom.completed_at else None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _message_to_response(message) -> MessageResponse:
|
||||||
|
"""
|
||||||
|
转换消息为响应模型
|
||||||
|
"""
|
||||||
|
return MessageResponse(
|
||||||
|
message_id=message.message_id,
|
||||||
|
room_id=message.room_id,
|
||||||
|
discussion_id=message.discussion_id,
|
||||||
|
agent_id=message.agent_id,
|
||||||
|
content=message.content,
|
||||||
|
message_type=message.message_type,
|
||||||
|
round=message.round,
|
||||||
|
created_at=message.created_at.isoformat()
|
||||||
|
)
|
||||||
136
backend/routers/discussions.py
Normal file
136
backend/routers/discussions.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
"""
|
||||||
|
讨论结果路由
|
||||||
|
"""
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from models.discussion_result import DiscussionResult
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 响应模型 ============
|
||||||
|
|
||||||
|
class DiscussionResponse(BaseModel):
|
||||||
|
"""讨论结果响应"""
|
||||||
|
discussion_id: str
|
||||||
|
room_id: str
|
||||||
|
objective: str
|
||||||
|
consensus_reached: bool
|
||||||
|
confidence: float
|
||||||
|
summary: str
|
||||||
|
action_items: List[str]
|
||||||
|
unresolved_issues: List[str]
|
||||||
|
key_decisions: List[str]
|
||||||
|
total_rounds: int
|
||||||
|
total_messages: int
|
||||||
|
participating_agents: List[str]
|
||||||
|
agent_contributions: Dict[str, int]
|
||||||
|
status: str
|
||||||
|
end_reason: str
|
||||||
|
created_at: str
|
||||||
|
completed_at: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class DiscussionListResponse(BaseModel):
|
||||||
|
"""讨论列表响应"""
|
||||||
|
discussions: List[DiscussionResponse]
|
||||||
|
total: int
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 路由处理 ============
|
||||||
|
|
||||||
|
@router.get("", response_model=DiscussionListResponse)
|
||||||
|
async def list_discussions(
|
||||||
|
room_id: Optional[str] = None,
|
||||||
|
limit: int = 20,
|
||||||
|
skip: int = 0
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取讨论结果列表
|
||||||
|
"""
|
||||||
|
query = {}
|
||||||
|
if room_id:
|
||||||
|
query["room_id"] = room_id
|
||||||
|
|
||||||
|
discussions = await DiscussionResult.find(query).sort(
|
||||||
|
"-created_at"
|
||||||
|
).skip(skip).limit(limit).to_list()
|
||||||
|
|
||||||
|
total = await DiscussionResult.find(query).count()
|
||||||
|
|
||||||
|
return DiscussionListResponse(
|
||||||
|
discussions=[_to_response(d) for d in discussions],
|
||||||
|
total=total
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{discussion_id}", response_model=DiscussionResponse)
|
||||||
|
async def get_discussion(discussion_id: str):
|
||||||
|
"""
|
||||||
|
获取指定讨论结果
|
||||||
|
"""
|
||||||
|
discussion = await DiscussionResult.find_one(
|
||||||
|
DiscussionResult.discussion_id == discussion_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not discussion:
|
||||||
|
raise HTTPException(status_code=404, detail="讨论记录不存在")
|
||||||
|
|
||||||
|
return _to_response(discussion)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/room/{room_id}", response_model=List[DiscussionResponse])
|
||||||
|
async def get_room_discussions(room_id: str, limit: int = 10):
|
||||||
|
"""
|
||||||
|
获取聊天室的讨论历史
|
||||||
|
"""
|
||||||
|
discussions = await DiscussionResult.find(
|
||||||
|
{"room_id": room_id}
|
||||||
|
).sort("-created_at").limit(limit).to_list()
|
||||||
|
|
||||||
|
return [_to_response(d) for d in discussions]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/room/{room_id}/latest", response_model=DiscussionResponse)
|
||||||
|
async def get_latest_discussion(room_id: str):
|
||||||
|
"""
|
||||||
|
获取聊天室最新的讨论结果
|
||||||
|
"""
|
||||||
|
discussion = await DiscussionResult.find(
|
||||||
|
{"room_id": room_id}
|
||||||
|
).sort("-created_at").first_or_none()
|
||||||
|
|
||||||
|
if not discussion:
|
||||||
|
raise HTTPException(status_code=404, detail="没有找到讨论记录")
|
||||||
|
|
||||||
|
return _to_response(discussion)
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 辅助函数 ============
|
||||||
|
|
||||||
|
def _to_response(discussion: DiscussionResult) -> DiscussionResponse:
|
||||||
|
"""
|
||||||
|
转换为响应模型
|
||||||
|
"""
|
||||||
|
return DiscussionResponse(
|
||||||
|
discussion_id=discussion.discussion_id,
|
||||||
|
room_id=discussion.room_id,
|
||||||
|
objective=discussion.objective,
|
||||||
|
consensus_reached=discussion.consensus_reached,
|
||||||
|
confidence=discussion.confidence,
|
||||||
|
summary=discussion.summary,
|
||||||
|
action_items=discussion.action_items,
|
||||||
|
unresolved_issues=discussion.unresolved_issues,
|
||||||
|
key_decisions=discussion.key_decisions,
|
||||||
|
total_rounds=discussion.total_rounds,
|
||||||
|
total_messages=discussion.total_messages,
|
||||||
|
participating_agents=discussion.participating_agents,
|
||||||
|
agent_contributions=discussion.agent_contributions,
|
||||||
|
status=discussion.status,
|
||||||
|
end_reason=discussion.end_reason,
|
||||||
|
created_at=discussion.created_at.isoformat(),
|
||||||
|
completed_at=discussion.completed_at.isoformat() if discussion.completed_at else None
|
||||||
|
)
|
||||||
241
backend/routers/providers.py
Normal file
241
backend/routers/providers.py
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
"""
|
||||||
|
AI接口管理路由
|
||||||
|
"""
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
from fastapi import APIRouter, HTTPException, status
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from services.ai_provider_service import AIProviderService
|
||||||
|
from utils.encryption import mask_api_key
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 请求/响应模型 ============
|
||||||
|
|
||||||
|
class ProxyConfigModel(BaseModel):
|
||||||
|
"""代理配置模型"""
|
||||||
|
http_proxy: Optional[str] = None
|
||||||
|
https_proxy: Optional[str] = None
|
||||||
|
no_proxy: List[str] = []
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitModel(BaseModel):
|
||||||
|
"""速率限制模型"""
|
||||||
|
requests_per_minute: int = 60
|
||||||
|
tokens_per_minute: int = 100000
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderCreateRequest(BaseModel):
|
||||||
|
"""创建AI接口请求"""
|
||||||
|
provider_type: str = Field(..., description="提供商类型: minimax, zhipu, openrouter, kimi, deepseek, gemini, ollama, llmstudio")
|
||||||
|
name: str = Field(..., description="自定义名称")
|
||||||
|
model: str = Field(..., description="模型名称")
|
||||||
|
api_key: str = Field(default="", description="API密钥")
|
||||||
|
base_url: str = Field(default="", description="API基础URL")
|
||||||
|
use_proxy: bool = Field(default=False, description="是否使用代理")
|
||||||
|
proxy_config: Optional[ProxyConfigModel] = None
|
||||||
|
rate_limit: Optional[RateLimitModel] = None
|
||||||
|
timeout: int = Field(default=60, description="超时时间(秒)")
|
||||||
|
extra_params: Dict[str, Any] = Field(default_factory=dict, description="额外参数")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"provider_type": "openrouter",
|
||||||
|
"name": "OpenRouter GPT-4",
|
||||||
|
"model": "openai/gpt-4-turbo",
|
||||||
|
"api_key": "sk-xxx",
|
||||||
|
"use_proxy": True,
|
||||||
|
"proxy_config": {
|
||||||
|
"http_proxy": "http://127.0.0.1:7890",
|
||||||
|
"https_proxy": "http://127.0.0.1:7890"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderUpdateRequest(BaseModel):
|
||||||
|
"""更新AI接口请求"""
|
||||||
|
name: Optional[str] = None
|
||||||
|
model: Optional[str] = None
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
base_url: Optional[str] = None
|
||||||
|
use_proxy: Optional[bool] = None
|
||||||
|
proxy_config: Optional[ProxyConfigModel] = None
|
||||||
|
rate_limit: Optional[RateLimitModel] = None
|
||||||
|
timeout: Optional[int] = None
|
||||||
|
extra_params: Optional[Dict[str, Any]] = None
|
||||||
|
enabled: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderResponse(BaseModel):
|
||||||
|
"""AI接口响应"""
|
||||||
|
provider_id: str
|
||||||
|
provider_type: str
|
||||||
|
name: str
|
||||||
|
api_key_masked: str
|
||||||
|
base_url: str
|
||||||
|
model: str
|
||||||
|
use_proxy: bool
|
||||||
|
proxy_config: Dict[str, Any]
|
||||||
|
rate_limit: Dict[str, int]
|
||||||
|
timeout: int
|
||||||
|
extra_params: Dict[str, Any]
|
||||||
|
enabled: bool
|
||||||
|
created_at: str
|
||||||
|
updated_at: str
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfigRequest(BaseModel):
|
||||||
|
"""测试配置请求"""
|
||||||
|
provider_type: str
|
||||||
|
api_key: str
|
||||||
|
base_url: str = ""
|
||||||
|
model: str = ""
|
||||||
|
use_proxy: bool = False
|
||||||
|
proxy_config: Optional[ProxyConfigModel] = None
|
||||||
|
timeout: int = 30
|
||||||
|
|
||||||
|
|
||||||
|
class TestResponse(BaseModel):
|
||||||
|
"""测试响应"""
|
||||||
|
success: bool
|
||||||
|
message: str
|
||||||
|
model: Optional[str] = None
|
||||||
|
latency_ms: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 路由处理 ============
|
||||||
|
|
||||||
|
@router.post("", response_model=ProviderResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_provider(request: ProviderCreateRequest):
|
||||||
|
"""
|
||||||
|
创建新的AI接口配置
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
provider = await AIProviderService.create_provider(
|
||||||
|
provider_type=request.provider_type,
|
||||||
|
name=request.name,
|
||||||
|
model=request.model,
|
||||||
|
api_key=request.api_key,
|
||||||
|
base_url=request.base_url,
|
||||||
|
use_proxy=request.use_proxy,
|
||||||
|
proxy_config=request.proxy_config.dict() if request.proxy_config else None,
|
||||||
|
rate_limit=request.rate_limit.dict() if request.rate_limit else None,
|
||||||
|
timeout=request.timeout,
|
||||||
|
extra_params=request.extra_params
|
||||||
|
)
|
||||||
|
|
||||||
|
return _to_response(provider)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建AI接口失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail="创建失败")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=List[ProviderResponse])
|
||||||
|
async def list_providers(enabled_only: bool = False):
|
||||||
|
"""
|
||||||
|
获取所有AI接口配置
|
||||||
|
"""
|
||||||
|
providers = await AIProviderService.get_all_providers(enabled_only)
|
||||||
|
return [_to_response(p) for p in providers]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{provider_id}", response_model=ProviderResponse)
|
||||||
|
async def get_provider(provider_id: str):
|
||||||
|
"""
|
||||||
|
获取指定AI接口配置
|
||||||
|
"""
|
||||||
|
provider = await AIProviderService.get_provider(provider_id)
|
||||||
|
if not provider:
|
||||||
|
raise HTTPException(status_code=404, detail="AI接口不存在")
|
||||||
|
return _to_response(provider)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{provider_id}", response_model=ProviderResponse)
|
||||||
|
async def update_provider(provider_id: str, request: ProviderUpdateRequest):
|
||||||
|
"""
|
||||||
|
更新AI接口配置
|
||||||
|
"""
|
||||||
|
update_data = request.dict(exclude_unset=True)
|
||||||
|
|
||||||
|
# 转换嵌套模型
|
||||||
|
if "proxy_config" in update_data and update_data["proxy_config"]:
|
||||||
|
update_data["proxy_config"] = update_data["proxy_config"].dict() if hasattr(update_data["proxy_config"], "dict") else update_data["proxy_config"]
|
||||||
|
if "rate_limit" in update_data and update_data["rate_limit"]:
|
||||||
|
update_data["rate_limit"] = update_data["rate_limit"].dict() if hasattr(update_data["rate_limit"], "dict") else update_data["rate_limit"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
provider = await AIProviderService.update_provider(provider_id, **update_data)
|
||||||
|
if not provider:
|
||||||
|
raise HTTPException(status_code=404, detail="AI接口不存在")
|
||||||
|
return _to_response(provider)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{provider_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def delete_provider(provider_id: str):
|
||||||
|
"""
|
||||||
|
删除AI接口配置
|
||||||
|
"""
|
||||||
|
success = await AIProviderService.delete_provider(provider_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=404, detail="AI接口不存在")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{provider_id}/test", response_model=TestResponse)
|
||||||
|
async def test_provider(provider_id: str):
|
||||||
|
"""
|
||||||
|
测试AI接口连接
|
||||||
|
"""
|
||||||
|
result = await AIProviderService.test_provider(provider_id)
|
||||||
|
return TestResponse(**result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/test", response_model=TestResponse)
|
||||||
|
async def test_provider_config(request: TestConfigRequest):
|
||||||
|
"""
|
||||||
|
测试AI接口配置(不保存)
|
||||||
|
"""
|
||||||
|
result = await AIProviderService.test_provider_config(
|
||||||
|
provider_type=request.provider_type,
|
||||||
|
api_key=request.api_key,
|
||||||
|
base_url=request.base_url,
|
||||||
|
model=request.model,
|
||||||
|
use_proxy=request.use_proxy,
|
||||||
|
proxy_config=request.proxy_config.dict() if request.proxy_config else None,
|
||||||
|
timeout=request.timeout
|
||||||
|
)
|
||||||
|
return TestResponse(**result)
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 辅助函数 ============
|
||||||
|
|
||||||
|
def _to_response(provider) -> ProviderResponse:
|
||||||
|
"""
|
||||||
|
转换为响应模型
|
||||||
|
"""
|
||||||
|
return ProviderResponse(
|
||||||
|
provider_id=provider.provider_id,
|
||||||
|
provider_type=provider.provider_type,
|
||||||
|
name=provider.name,
|
||||||
|
api_key_masked=mask_api_key(provider.api_key) if provider.api_key else "",
|
||||||
|
base_url=provider.base_url,
|
||||||
|
model=provider.model,
|
||||||
|
use_proxy=provider.use_proxy,
|
||||||
|
proxy_config=provider.proxy_config,
|
||||||
|
rate_limit=provider.rate_limit,
|
||||||
|
timeout=provider.timeout,
|
||||||
|
extra_params=provider.extra_params,
|
||||||
|
enabled=provider.enabled,
|
||||||
|
created_at=provider.created_at.isoformat(),
|
||||||
|
updated_at=provider.updated_at.isoformat()
|
||||||
|
)
|
||||||
22
backend/services/__init__.py
Normal file
22
backend/services/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
"""
|
||||||
|
业务服务模块
|
||||||
|
"""
|
||||||
|
from .ai_provider_service import AIProviderService
|
||||||
|
from .agent_service import AgentService
|
||||||
|
from .chatroom_service import ChatRoomService
|
||||||
|
from .message_router import MessageRouter
|
||||||
|
from .discussion_engine import DiscussionEngine
|
||||||
|
from .consensus_manager import ConsensusManager
|
||||||
|
from .mcp_service import MCPService
|
||||||
|
from .memory_service import MemoryService
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AIProviderService",
|
||||||
|
"AgentService",
|
||||||
|
"ChatRoomService",
|
||||||
|
"MessageRouter",
|
||||||
|
"DiscussionEngine",
|
||||||
|
"ConsensusManager",
|
||||||
|
"MCPService",
|
||||||
|
"MemoryService",
|
||||||
|
]
|
||||||
438
backend/services/agent_service.py
Normal file
438
backend/services/agent_service.py
Normal file
@@ -0,0 +1,438 @@
|
|||||||
|
"""
|
||||||
|
Agent服务
|
||||||
|
管理AI代理的配置
|
||||||
|
"""
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from models.agent import Agent
|
||||||
|
from services.ai_provider_service import AIProviderService
|
||||||
|
|
||||||
|
|
||||||
|
class AgentService:
|
||||||
|
"""
|
||||||
|
Agent服务类
|
||||||
|
负责Agent的CRUD操作
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_agent(
|
||||||
|
cls,
|
||||||
|
name: str,
|
||||||
|
role: str,
|
||||||
|
system_prompt: str,
|
||||||
|
provider_id: str,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 2000,
|
||||||
|
capabilities: Optional[Dict[str, Any]] = None,
|
||||||
|
behavior: Optional[Dict[str, Any]] = None,
|
||||||
|
avatar: Optional[str] = None,
|
||||||
|
color: str = "#1890ff"
|
||||||
|
) -> Agent:
|
||||||
|
"""
|
||||||
|
创建新的Agent
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Agent名称
|
||||||
|
role: 角色定义
|
||||||
|
system_prompt: 系统提示词
|
||||||
|
provider_id: 使用的AI接口ID
|
||||||
|
temperature: 温度参数
|
||||||
|
max_tokens: 最大token数
|
||||||
|
capabilities: 能力配置
|
||||||
|
behavior: 行为配置
|
||||||
|
avatar: 头像URL
|
||||||
|
color: 代表颜色
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
创建的Agent文档
|
||||||
|
"""
|
||||||
|
# 验证AI接口存在
|
||||||
|
provider = await AIProviderService.get_provider(provider_id)
|
||||||
|
if not provider:
|
||||||
|
raise ValueError(f"AI接口不存在: {provider_id}")
|
||||||
|
|
||||||
|
# 生成唯一ID
|
||||||
|
agent_id = f"agent-{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
# 默认能力配置
|
||||||
|
default_capabilities = {
|
||||||
|
"memory_enabled": False,
|
||||||
|
"mcp_tools": [],
|
||||||
|
"skills": [],
|
||||||
|
"multimodal": False
|
||||||
|
}
|
||||||
|
if capabilities:
|
||||||
|
default_capabilities.update(capabilities)
|
||||||
|
|
||||||
|
# 默认行为配置
|
||||||
|
default_behavior = {
|
||||||
|
"speak_threshold": 0.5,
|
||||||
|
"max_speak_per_round": 2,
|
||||||
|
"speak_style": "balanced"
|
||||||
|
}
|
||||||
|
if behavior:
|
||||||
|
default_behavior.update(behavior)
|
||||||
|
|
||||||
|
# 创建文档
|
||||||
|
agent = Agent(
|
||||||
|
agent_id=agent_id,
|
||||||
|
name=name,
|
||||||
|
role=role,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
provider_id=provider_id,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
capabilities=default_capabilities,
|
||||||
|
behavior=default_behavior,
|
||||||
|
avatar=avatar,
|
||||||
|
color=color,
|
||||||
|
enabled=True,
|
||||||
|
created_at=datetime.utcnow(),
|
||||||
|
updated_at=datetime.utcnow()
|
||||||
|
)
|
||||||
|
|
||||||
|
await agent.insert()
|
||||||
|
|
||||||
|
logger.info(f"创建Agent: {agent_id} ({name})")
|
||||||
|
return agent
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_agent(cls, agent_id: str) -> Optional[Agent]:
|
||||||
|
"""
|
||||||
|
获取指定Agent
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Agent文档或None
|
||||||
|
"""
|
||||||
|
return await Agent.find_one(Agent.agent_id == agent_id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_all_agents(
|
||||||
|
cls,
|
||||||
|
enabled_only: bool = False
|
||||||
|
) -> List[Agent]:
|
||||||
|
"""
|
||||||
|
获取所有Agent
|
||||||
|
|
||||||
|
Args:
|
||||||
|
enabled_only: 是否只返回启用的Agent
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Agent列表
|
||||||
|
"""
|
||||||
|
if enabled_only:
|
||||||
|
return await Agent.find(Agent.enabled == True).to_list()
|
||||||
|
return await Agent.find_all().to_list()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_agents_by_ids(
|
||||||
|
cls,
|
||||||
|
agent_ids: List[str]
|
||||||
|
) -> List[Agent]:
|
||||||
|
"""
|
||||||
|
根据ID列表获取多个Agent
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_ids: Agent ID列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Agent列表
|
||||||
|
"""
|
||||||
|
return await Agent.find(
|
||||||
|
{"agent_id": {"$in": agent_ids}}
|
||||||
|
).to_list()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def update_agent(
|
||||||
|
cls,
|
||||||
|
agent_id: str,
|
||||||
|
**kwargs
|
||||||
|
) -> Optional[Agent]:
|
||||||
|
"""
|
||||||
|
更新Agent配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent ID
|
||||||
|
**kwargs: 要更新的字段
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
更新后的Agent或None
|
||||||
|
"""
|
||||||
|
agent = await cls.get_agent(agent_id)
|
||||||
|
if not agent:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 如果更新了provider_id,验证其存在
|
||||||
|
if "provider_id" in kwargs:
|
||||||
|
provider = await AIProviderService.get_provider(kwargs["provider_id"])
|
||||||
|
if not provider:
|
||||||
|
raise ValueError(f"AI接口不存在: {kwargs['provider_id']}")
|
||||||
|
|
||||||
|
# 更新字段
|
||||||
|
kwargs["updated_at"] = datetime.utcnow()
|
||||||
|
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if hasattr(agent, key):
|
||||||
|
setattr(agent, key, value)
|
||||||
|
|
||||||
|
await agent.save()
|
||||||
|
|
||||||
|
logger.info(f"更新Agent: {agent_id}")
|
||||||
|
return agent
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def delete_agent(cls, agent_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
删除Agent
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否删除成功
|
||||||
|
"""
|
||||||
|
agent = await cls.get_agent(agent_id)
|
||||||
|
if not agent:
|
||||||
|
return False
|
||||||
|
|
||||||
|
await agent.delete()
|
||||||
|
|
||||||
|
logger.info(f"删除Agent: {agent_id}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def test_agent(
|
||||||
|
cls,
|
||||||
|
agent_id: str,
|
||||||
|
test_message: str = "你好,请简单介绍一下你自己。"
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
测试Agent对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent ID
|
||||||
|
test_message: 测试消息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
测试结果
|
||||||
|
"""
|
||||||
|
agent = await cls.get_agent(agent_id)
|
||||||
|
if not agent:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": f"Agent不存在: {agent_id}"
|
||||||
|
}
|
||||||
|
|
||||||
|
if not agent.enabled:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": "Agent已禁用"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建消息
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": agent.system_prompt},
|
||||||
|
{"role": "user", "content": test_message}
|
||||||
|
]
|
||||||
|
|
||||||
|
# 调用AI接口
|
||||||
|
response = await AIProviderService.chat(
|
||||||
|
provider_id=agent.provider_id,
|
||||||
|
messages=messages,
|
||||||
|
temperature=agent.temperature,
|
||||||
|
max_tokens=agent.max_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.success:
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": "测试成功",
|
||||||
|
"response": response.content,
|
||||||
|
"model": response.model,
|
||||||
|
"tokens": response.total_tokens,
|
||||||
|
"latency_ms": response.latency_ms
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": response.error
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def duplicate_agent(
|
||||||
|
cls,
|
||||||
|
agent_id: str,
|
||||||
|
new_name: Optional[str] = None
|
||||||
|
) -> Optional[Agent]:
|
||||||
|
"""
|
||||||
|
复制Agent
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: 源Agent ID
|
||||||
|
new_name: 新Agent名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
新创建的Agent或None
|
||||||
|
"""
|
||||||
|
source_agent = await cls.get_agent(agent_id)
|
||||||
|
if not source_agent:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return await cls.create_agent(
|
||||||
|
name=new_name or f"{source_agent.name} (副本)",
|
||||||
|
role=source_agent.role,
|
||||||
|
system_prompt=source_agent.system_prompt,
|
||||||
|
provider_id=source_agent.provider_id,
|
||||||
|
temperature=source_agent.temperature,
|
||||||
|
max_tokens=source_agent.max_tokens,
|
||||||
|
capabilities=source_agent.capabilities,
|
||||||
|
behavior=source_agent.behavior,
|
||||||
|
avatar=source_agent.avatar,
|
||||||
|
color=source_agent.color
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def generate_system_prompt(
|
||||||
|
cls,
|
||||||
|
provider_id: str,
|
||||||
|
name: str,
|
||||||
|
role: str,
|
||||||
|
description: Optional[str] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
使用AI生成Agent系统提示词
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_id: AI接口ID
|
||||||
|
name: Agent名称
|
||||||
|
role: 角色定位
|
||||||
|
description: 额外描述(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
生成结果,包含success和生成的prompt
|
||||||
|
"""
|
||||||
|
# 验证AI接口存在
|
||||||
|
provider = await AIProviderService.get_provider(provider_id)
|
||||||
|
if not provider:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": f"AI接口不存在: {provider_id}"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建生成提示词的请求
|
||||||
|
generate_prompt = f"""请为一个AI Agent编写系统提示词(system prompt)。
|
||||||
|
|
||||||
|
Agent名称:{name}
|
||||||
|
角色定位:{role}
|
||||||
|
{f'补充说明:{description}' if description else ''}
|
||||||
|
|
||||||
|
要求:
|
||||||
|
1. 提示词应简洁专业,控制在200字以内
|
||||||
|
2. 明确该Agent的核心职责和专业领域
|
||||||
|
3. 说明在多Agent讨论中应该关注什么
|
||||||
|
4. 使用中文编写
|
||||||
|
5. 不要包含任何问候语或开场白,直接给出提示词内容
|
||||||
|
|
||||||
|
请直接输出系统提示词,不要有任何额外的解释或包装。"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
messages = [{"role": "user", "content": generate_prompt}]
|
||||||
|
|
||||||
|
response = await AIProviderService.chat(
|
||||||
|
provider_id=provider_id,
|
||||||
|
messages=messages,
|
||||||
|
temperature=0.7,
|
||||||
|
max_tokens=1000
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.success:
|
||||||
|
# 清理可能的包装文本
|
||||||
|
content = response.content.strip()
|
||||||
|
# 移除可能的markdown代码块标记
|
||||||
|
if content.startswith("```"):
|
||||||
|
lines = content.split("\n")
|
||||||
|
content = "\n".join(lines[1:])
|
||||||
|
if content.endswith("```"):
|
||||||
|
content = content[:-3]
|
||||||
|
content = content.strip()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"prompt": content,
|
||||||
|
"model": response.model,
|
||||||
|
"tokens": response.total_tokens
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": response.error or "生成失败"
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"生成系统提示词失败: {e}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": f"生成失败: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Agent预设模板
|
||||||
|
AGENT_TEMPLATES = {
|
||||||
|
"product_manager": {
|
||||||
|
"name": "产品经理",
|
||||||
|
"role": "产品规划和需求分析专家",
|
||||||
|
"system_prompt": """你是一位经验丰富的产品经理,擅长:
|
||||||
|
- 分析用户需求和痛点
|
||||||
|
- 制定产品策略和路线图
|
||||||
|
- 平衡业务目标和用户体验
|
||||||
|
- 与团队协作推进产品迭代
|
||||||
|
|
||||||
|
在讨论中,你需要从产品角度出发,关注用户价值、商业可行性和优先级排序。
|
||||||
|
请用专业但易懂的语言表达观点。""",
|
||||||
|
"color": "#1890ff"
|
||||||
|
},
|
||||||
|
"developer": {
|
||||||
|
"name": "开发工程师",
|
||||||
|
"role": "技术实现和架构设计专家",
|
||||||
|
"system_prompt": """你是一位资深的软件开发工程师,擅长:
|
||||||
|
- 系统架构设计
|
||||||
|
- 代码实现和优化
|
||||||
|
- 技术方案评估
|
||||||
|
- 性能和安全考量
|
||||||
|
|
||||||
|
在讨论中,你需要从技术角度出发,关注实现可行性、技术债务和最佳实践。
|
||||||
|
请提供具体的技术建议和潜在风险评估。""",
|
||||||
|
"color": "#52c41a"
|
||||||
|
},
|
||||||
|
"designer": {
|
||||||
|
"name": "设计师",
|
||||||
|
"role": "用户体验和界面设计专家",
|
||||||
|
"system_prompt": """你是一位专业的UI/UX设计师,擅长:
|
||||||
|
- 用户体验设计
|
||||||
|
- 界面视觉设计
|
||||||
|
- 交互流程优化
|
||||||
|
- 设计系统构建
|
||||||
|
|
||||||
|
在讨论中,你需要从设计角度出发,关注用户体验、视觉美感和交互流畅性。
|
||||||
|
请提供设计建议并考虑可用性和一致性。""",
|
||||||
|
"color": "#eb2f96"
|
||||||
|
},
|
||||||
|
"moderator": {
|
||||||
|
"name": "主持人",
|
||||||
|
"role": "讨论主持和共识判断专家",
|
||||||
|
"system_prompt": """你是讨论的主持人,负责:
|
||||||
|
- 引导讨论方向
|
||||||
|
- 总结各方观点
|
||||||
|
- 判断是否达成共识
|
||||||
|
- 提炼行动要点
|
||||||
|
|
||||||
|
在讨论中,你需要保持中立,促进有效沟通,并在适当时机总结讨论成果。
|
||||||
|
当各方观点趋于一致时,请明确指出并总结共识内容。""",
|
||||||
|
"color": "#722ed1"
|
||||||
|
}
|
||||||
|
}
|
||||||
364
backend/services/ai_provider_service.py
Normal file
364
backend/services/ai_provider_service.py
Normal file
@@ -0,0 +1,364 @@
|
|||||||
|
"""
|
||||||
|
AI接口提供商服务
|
||||||
|
管理AI接口的配置和调用
|
||||||
|
"""
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from models.ai_provider import AIProvider
|
||||||
|
from adapters import get_adapter, BaseAdapter, ChatMessage, AdapterResponse
|
||||||
|
from utils.encryption import encrypt_api_key, decrypt_api_key
|
||||||
|
from utils.rate_limiter import rate_limiter
|
||||||
|
|
||||||
|
|
||||||
|
class AIProviderService:
|
||||||
|
"""
|
||||||
|
AI接口提供商服务类
|
||||||
|
负责AI接口的CRUD操作和调用
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 缓存适配器实例
|
||||||
|
_adapter_cache: Dict[str, BaseAdapter] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_provider(
|
||||||
|
cls,
|
||||||
|
provider_type: str,
|
||||||
|
name: str,
|
||||||
|
model: str,
|
||||||
|
api_key: str = "",
|
||||||
|
base_url: str = "",
|
||||||
|
use_proxy: bool = False,
|
||||||
|
proxy_config: Optional[Dict[str, Any]] = None,
|
||||||
|
rate_limit: Optional[Dict[str, int]] = None,
|
||||||
|
timeout: int = 60,
|
||||||
|
extra_params: Optional[Dict[str, Any]] = None
|
||||||
|
) -> AIProvider:
|
||||||
|
"""
|
||||||
|
创建新的AI接口配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_type: 提供商类型
|
||||||
|
name: 自定义名称
|
||||||
|
model: 模型名称
|
||||||
|
api_key: API密钥
|
||||||
|
base_url: API基础URL
|
||||||
|
use_proxy: 是否使用代理
|
||||||
|
proxy_config: 代理配置
|
||||||
|
rate_limit: 速率限制配置
|
||||||
|
timeout: 超时时间
|
||||||
|
extra_params: 额外参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
创建的AIProvider文档
|
||||||
|
"""
|
||||||
|
# 验证提供商类型
|
||||||
|
try:
|
||||||
|
get_adapter(provider_type)
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError(f"不支持的提供商类型: {provider_type}")
|
||||||
|
|
||||||
|
# 生成唯一ID
|
||||||
|
provider_id = f"{provider_type}-{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
# 加密API密钥
|
||||||
|
encrypted_key = encrypt_api_key(api_key) if api_key else ""
|
||||||
|
|
||||||
|
# 创建文档
|
||||||
|
provider = AIProvider(
|
||||||
|
provider_id=provider_id,
|
||||||
|
provider_type=provider_type,
|
||||||
|
name=name,
|
||||||
|
api_key=encrypted_key,
|
||||||
|
base_url=base_url,
|
||||||
|
model=model,
|
||||||
|
use_proxy=use_proxy,
|
||||||
|
proxy_config=proxy_config or {},
|
||||||
|
rate_limit=rate_limit or {"requests_per_minute": 60, "tokens_per_minute": 100000},
|
||||||
|
timeout=timeout,
|
||||||
|
extra_params=extra_params or {},
|
||||||
|
enabled=True,
|
||||||
|
created_at=datetime.utcnow(),
|
||||||
|
updated_at=datetime.utcnow()
|
||||||
|
)
|
||||||
|
|
||||||
|
await provider.insert()
|
||||||
|
|
||||||
|
# 注册速率限制
|
||||||
|
rate_limiter.register(
|
||||||
|
provider_id,
|
||||||
|
provider.rate_limit.get("requests_per_minute", 60),
|
||||||
|
provider.rate_limit.get("tokens_per_minute", 100000)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"创建AI接口配置: {provider_id} ({name})")
|
||||||
|
return provider
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_provider(cls, provider_id: str) -> Optional[AIProvider]:
|
||||||
|
"""
|
||||||
|
获取指定AI接口配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_id: 接口ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AIProvider文档或None
|
||||||
|
"""
|
||||||
|
return await AIProvider.find_one(AIProvider.provider_id == provider_id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_all_providers(
|
||||||
|
cls,
|
||||||
|
enabled_only: bool = False
|
||||||
|
) -> List[AIProvider]:
|
||||||
|
"""
|
||||||
|
获取所有AI接口配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
enabled_only: 是否只返回启用的接口
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AIProvider列表
|
||||||
|
"""
|
||||||
|
if enabled_only:
|
||||||
|
return await AIProvider.find(AIProvider.enabled == True).to_list()
|
||||||
|
return await AIProvider.find_all().to_list()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def update_provider(
|
||||||
|
cls,
|
||||||
|
provider_id: str,
|
||||||
|
**kwargs
|
||||||
|
) -> Optional[AIProvider]:
|
||||||
|
"""
|
||||||
|
更新AI接口配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_id: 接口ID
|
||||||
|
**kwargs: 要更新的字段
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
更新后的AIProvider或None
|
||||||
|
"""
|
||||||
|
provider = await cls.get_provider(provider_id)
|
||||||
|
if not provider:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 如果更新了API密钥,需要加密
|
||||||
|
if "api_key" in kwargs and kwargs["api_key"]:
|
||||||
|
kwargs["api_key"] = encrypt_api_key(kwargs["api_key"])
|
||||||
|
|
||||||
|
# 更新字段
|
||||||
|
kwargs["updated_at"] = datetime.utcnow()
|
||||||
|
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if hasattr(provider, key):
|
||||||
|
setattr(provider, key, value)
|
||||||
|
|
||||||
|
await provider.save()
|
||||||
|
|
||||||
|
# 清除适配器缓存
|
||||||
|
cls._adapter_cache.pop(provider_id, None)
|
||||||
|
|
||||||
|
# 更新速率限制
|
||||||
|
if "rate_limit" in kwargs:
|
||||||
|
rate_limiter.unregister(provider_id)
|
||||||
|
rate_limiter.register(
|
||||||
|
provider_id,
|
||||||
|
provider.rate_limit.get("requests_per_minute", 60),
|
||||||
|
provider.rate_limit.get("tokens_per_minute", 100000)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"更新AI接口配置: {provider_id}")
|
||||||
|
return provider
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def delete_provider(cls, provider_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
删除AI接口配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_id: 接口ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否删除成功
|
||||||
|
"""
|
||||||
|
provider = await cls.get_provider(provider_id)
|
||||||
|
if not provider:
|
||||||
|
return False
|
||||||
|
|
||||||
|
await provider.delete()
|
||||||
|
|
||||||
|
# 清除缓存和速率限制
|
||||||
|
cls._adapter_cache.pop(provider_id, None)
|
||||||
|
rate_limiter.unregister(provider_id)
|
||||||
|
|
||||||
|
logger.info(f"删除AI接口配置: {provider_id}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_adapter(cls, provider_id: str) -> Optional[BaseAdapter]:
|
||||||
|
"""
|
||||||
|
获取AI接口的适配器实例
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_id: 接口ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
适配器实例或None
|
||||||
|
"""
|
||||||
|
# 检查缓存
|
||||||
|
if provider_id in cls._adapter_cache:
|
||||||
|
return cls._adapter_cache[provider_id]
|
||||||
|
|
||||||
|
provider = await cls.get_provider(provider_id)
|
||||||
|
if not provider or not provider.enabled:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 解密API密钥
|
||||||
|
api_key = decrypt_api_key(provider.api_key) if provider.api_key else ""
|
||||||
|
|
||||||
|
# 创建适配器
|
||||||
|
adapter_class = get_adapter(provider.provider_type)
|
||||||
|
adapter = adapter_class(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=provider.base_url,
|
||||||
|
model=provider.model,
|
||||||
|
use_proxy=provider.use_proxy,
|
||||||
|
proxy_config=provider.proxy_config,
|
||||||
|
timeout=provider.timeout,
|
||||||
|
**provider.extra_params
|
||||||
|
)
|
||||||
|
|
||||||
|
# 缓存适配器
|
||||||
|
cls._adapter_cache[provider_id] = adapter
|
||||||
|
|
||||||
|
return adapter
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def chat(
|
||||||
|
cls,
|
||||||
|
provider_id: str,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 2000,
|
||||||
|
**kwargs
|
||||||
|
) -> AdapterResponse:
|
||||||
|
"""
|
||||||
|
调用AI接口进行对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_id: 接口ID
|
||||||
|
messages: 消息列表 [{"role": "user", "content": "..."}]
|
||||||
|
temperature: 温度参数
|
||||||
|
max_tokens: 最大token数
|
||||||
|
**kwargs: 额外参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
适配器响应
|
||||||
|
"""
|
||||||
|
adapter = await cls.get_adapter(provider_id)
|
||||||
|
if not adapter:
|
||||||
|
return AdapterResponse(
|
||||||
|
success=False,
|
||||||
|
error=f"AI接口不存在或未启用: {provider_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查速率限制
|
||||||
|
estimated_tokens = sum(len(m.get("content", "")) for m in messages) // 4
|
||||||
|
if not await rate_limiter.acquire_wait(provider_id, estimated_tokens):
|
||||||
|
return AdapterResponse(
|
||||||
|
success=False,
|
||||||
|
error="请求频率超限,请稍后重试"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 转换消息格式
|
||||||
|
chat_messages = [
|
||||||
|
ChatMessage(
|
||||||
|
role=m.get("role", "user"),
|
||||||
|
content=m.get("content", ""),
|
||||||
|
name=m.get("name")
|
||||||
|
)
|
||||||
|
for m in messages
|
||||||
|
]
|
||||||
|
|
||||||
|
# 调用适配器
|
||||||
|
response = await adapter.chat(
|
||||||
|
messages=chat_messages,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def test_provider(cls, provider_id: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
测试AI接口连接
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_id: 接口ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
测试结果
|
||||||
|
"""
|
||||||
|
adapter = await cls.get_adapter(provider_id)
|
||||||
|
if not adapter:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": f"AI接口不存在或未启用: {provider_id}"
|
||||||
|
}
|
||||||
|
|
||||||
|
return await adapter.test_connection()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def test_provider_config(
|
||||||
|
cls,
|
||||||
|
provider_type: str,
|
||||||
|
api_key: str,
|
||||||
|
base_url: str = "",
|
||||||
|
model: str = "",
|
||||||
|
use_proxy: bool = False,
|
||||||
|
proxy_config: Optional[Dict[str, Any]] = None,
|
||||||
|
timeout: int = 30,
|
||||||
|
**kwargs
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
测试AI接口配置(不保存)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_type: 提供商类型
|
||||||
|
api_key: API密钥
|
||||||
|
base_url: API基础URL
|
||||||
|
model: 模型名称
|
||||||
|
use_proxy: 是否使用代理
|
||||||
|
proxy_config: 代理配置
|
||||||
|
timeout: 超时时间
|
||||||
|
**kwargs: 额外参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
测试结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
adapter_class = get_adapter(provider_type)
|
||||||
|
except ValueError:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": f"不支持的提供商类型: {provider_type}"
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter = adapter_class(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
model=model,
|
||||||
|
use_proxy=use_proxy,
|
||||||
|
proxy_config=proxy_config,
|
||||||
|
timeout=timeout,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return await adapter.test_connection()
|
||||||
357
backend/services/chatroom_service.py
Normal file
357
backend/services/chatroom_service.py
Normal file
@@ -0,0 +1,357 @@
|
|||||||
|
"""
|
||||||
|
聊天室服务
|
||||||
|
管理聊天室的创建和状态
|
||||||
|
"""
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from models.chatroom import ChatRoom, ChatRoomStatus
|
||||||
|
from models.message import Message
|
||||||
|
from services.agent_service import AgentService
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRoomService:
|
||||||
|
"""
|
||||||
|
聊天室服务类
|
||||||
|
负责聊天室的CRUD操作
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_chatroom(
|
||||||
|
cls,
|
||||||
|
name: str,
|
||||||
|
description: str = "",
|
||||||
|
agents: Optional[List[str]] = None,
|
||||||
|
moderator_agent_id: Optional[str] = None,
|
||||||
|
config: Optional[Dict[str, Any]] = None
|
||||||
|
) -> ChatRoom:
|
||||||
|
"""
|
||||||
|
创建新的聊天室
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 聊天室名称
|
||||||
|
description: 描述
|
||||||
|
agents: Agent ID列表
|
||||||
|
moderator_agent_id: 主持人Agent ID
|
||||||
|
config: 聊天室配置
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
创建的ChatRoom文档
|
||||||
|
"""
|
||||||
|
# 验证Agent存在
|
||||||
|
if agents:
|
||||||
|
existing_agents = await AgentService.get_agents_by_ids(agents)
|
||||||
|
existing_ids = {a.agent_id for a in existing_agents}
|
||||||
|
missing_ids = set(agents) - existing_ids
|
||||||
|
if missing_ids:
|
||||||
|
raise ValueError(f"Agent不存在: {', '.join(missing_ids)}")
|
||||||
|
|
||||||
|
# 验证主持人Agent
|
||||||
|
if moderator_agent_id:
|
||||||
|
moderator = await AgentService.get_agent(moderator_agent_id)
|
||||||
|
if not moderator:
|
||||||
|
raise ValueError(f"主持人Agent不存在: {moderator_agent_id}")
|
||||||
|
|
||||||
|
# 生成唯一ID
|
||||||
|
room_id = f"room-{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
# 默认配置
|
||||||
|
default_config = {
|
||||||
|
"max_rounds": 50,
|
||||||
|
"message_history_size": 20,
|
||||||
|
"consensus_threshold": 0.8,
|
||||||
|
"round_interval": 1.0,
|
||||||
|
"allow_user_interrupt": True
|
||||||
|
}
|
||||||
|
if config:
|
||||||
|
default_config.update(config)
|
||||||
|
|
||||||
|
# 创建文档
|
||||||
|
chatroom = ChatRoom(
|
||||||
|
room_id=room_id,
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
objective="",
|
||||||
|
agents=agents or [],
|
||||||
|
moderator_agent_id=moderator_agent_id,
|
||||||
|
config=default_config,
|
||||||
|
status=ChatRoomStatus.IDLE.value,
|
||||||
|
current_round=0,
|
||||||
|
created_at=datetime.utcnow(),
|
||||||
|
updated_at=datetime.utcnow()
|
||||||
|
)
|
||||||
|
|
||||||
|
await chatroom.insert()
|
||||||
|
|
||||||
|
logger.info(f"创建聊天室: {room_id} ({name})")
|
||||||
|
return chatroom
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_chatroom(cls, room_id: str) -> Optional[ChatRoom]:
|
||||||
|
"""
|
||||||
|
获取指定聊天室
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChatRoom文档或None
|
||||||
|
"""
|
||||||
|
return await ChatRoom.find_one(ChatRoom.room_id == room_id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_all_chatrooms(cls) -> List[ChatRoom]:
|
||||||
|
"""
|
||||||
|
获取所有聊天室
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChatRoom列表
|
||||||
|
"""
|
||||||
|
return await ChatRoom.find_all().to_list()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def update_chatroom(
|
||||||
|
cls,
|
||||||
|
room_id: str,
|
||||||
|
**kwargs
|
||||||
|
) -> Optional[ChatRoom]:
|
||||||
|
"""
|
||||||
|
更新聊天室配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
**kwargs: 要更新的字段
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
更新后的ChatRoom或None
|
||||||
|
"""
|
||||||
|
chatroom = await cls.get_chatroom(room_id)
|
||||||
|
if not chatroom:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 验证Agent
|
||||||
|
if "agents" in kwargs:
|
||||||
|
existing_agents = await AgentService.get_agents_by_ids(kwargs["agents"])
|
||||||
|
existing_ids = {a.agent_id for a in existing_agents}
|
||||||
|
missing_ids = set(kwargs["agents"]) - existing_ids
|
||||||
|
if missing_ids:
|
||||||
|
raise ValueError(f"Agent不存在: {', '.join(missing_ids)}")
|
||||||
|
|
||||||
|
# 验证主持人
|
||||||
|
if "moderator_agent_id" in kwargs and kwargs["moderator_agent_id"]:
|
||||||
|
moderator = await AgentService.get_agent(kwargs["moderator_agent_id"])
|
||||||
|
if not moderator:
|
||||||
|
raise ValueError(f"主持人Agent不存在: {kwargs['moderator_agent_id']}")
|
||||||
|
|
||||||
|
# 更新字段
|
||||||
|
kwargs["updated_at"] = datetime.utcnow()
|
||||||
|
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if hasattr(chatroom, key):
|
||||||
|
setattr(chatroom, key, value)
|
||||||
|
|
||||||
|
await chatroom.save()
|
||||||
|
|
||||||
|
logger.info(f"更新聊天室: {room_id}")
|
||||||
|
return chatroom
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def delete_chatroom(cls, room_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
删除聊天室
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否删除成功
|
||||||
|
"""
|
||||||
|
chatroom = await cls.get_chatroom(room_id)
|
||||||
|
if not chatroom:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 删除相关消息
|
||||||
|
await Message.find(Message.room_id == room_id).delete()
|
||||||
|
|
||||||
|
await chatroom.delete()
|
||||||
|
|
||||||
|
logger.info(f"删除聊天室: {room_id}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def add_agent(cls, room_id: str, agent_id: str) -> Optional[ChatRoom]:
|
||||||
|
"""
|
||||||
|
向聊天室添加Agent
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
agent_id: Agent ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
更新后的ChatRoom或None
|
||||||
|
"""
|
||||||
|
chatroom = await cls.get_chatroom(room_id)
|
||||||
|
if not chatroom:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 验证Agent存在
|
||||||
|
agent = await AgentService.get_agent(agent_id)
|
||||||
|
if not agent:
|
||||||
|
raise ValueError(f"Agent不存在: {agent_id}")
|
||||||
|
|
||||||
|
# 添加Agent
|
||||||
|
if agent_id not in chatroom.agents:
|
||||||
|
chatroom.agents.append(agent_id)
|
||||||
|
chatroom.updated_at = datetime.utcnow()
|
||||||
|
await chatroom.save()
|
||||||
|
|
||||||
|
return chatroom
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def remove_agent(cls, room_id: str, agent_id: str) -> Optional[ChatRoom]:
|
||||||
|
"""
|
||||||
|
从聊天室移除Agent
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
agent_id: Agent ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
更新后的ChatRoom或None
|
||||||
|
"""
|
||||||
|
chatroom = await cls.get_chatroom(room_id)
|
||||||
|
if not chatroom:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 移除Agent
|
||||||
|
if agent_id in chatroom.agents:
|
||||||
|
chatroom.agents.remove(agent_id)
|
||||||
|
chatroom.updated_at = datetime.utcnow()
|
||||||
|
await chatroom.save()
|
||||||
|
|
||||||
|
return chatroom
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def set_objective(
|
||||||
|
cls,
|
||||||
|
room_id: str,
|
||||||
|
objective: str
|
||||||
|
) -> Optional[ChatRoom]:
|
||||||
|
"""
|
||||||
|
设置讨论目标
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
objective: 讨论目标
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
更新后的ChatRoom或None
|
||||||
|
"""
|
||||||
|
return await cls.update_chatroom(room_id, objective=objective)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def update_status(
|
||||||
|
cls,
|
||||||
|
room_id: str,
|
||||||
|
status: ChatRoomStatus
|
||||||
|
) -> Optional[ChatRoom]:
|
||||||
|
"""
|
||||||
|
更新聊天室状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
status: 新状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
更新后的ChatRoom或None
|
||||||
|
"""
|
||||||
|
chatroom = await cls.get_chatroom(room_id)
|
||||||
|
if not chatroom:
|
||||||
|
return None
|
||||||
|
|
||||||
|
chatroom.status = status.value
|
||||||
|
chatroom.updated_at = datetime.utcnow()
|
||||||
|
|
||||||
|
if status == ChatRoomStatus.COMPLETED:
|
||||||
|
chatroom.completed_at = datetime.utcnow()
|
||||||
|
|
||||||
|
await chatroom.save()
|
||||||
|
|
||||||
|
logger.info(f"聊天室状态更新: {room_id} -> {status.value}")
|
||||||
|
return chatroom
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def increment_round(cls, room_id: str) -> Optional[ChatRoom]:
|
||||||
|
"""
|
||||||
|
增加轮次计数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
更新后的ChatRoom或None
|
||||||
|
"""
|
||||||
|
chatroom = await cls.get_chatroom(room_id)
|
||||||
|
if not chatroom:
|
||||||
|
return None
|
||||||
|
|
||||||
|
chatroom.current_round += 1
|
||||||
|
chatroom.updated_at = datetime.utcnow()
|
||||||
|
await chatroom.save()
|
||||||
|
|
||||||
|
return chatroom
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_messages(
|
||||||
|
cls,
|
||||||
|
room_id: str,
|
||||||
|
limit: int = 50,
|
||||||
|
skip: int = 0,
|
||||||
|
discussion_id: Optional[str] = None
|
||||||
|
) -> List[Message]:
|
||||||
|
"""
|
||||||
|
获取聊天室消息历史
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
limit: 返回数量限制
|
||||||
|
skip: 跳过数量
|
||||||
|
discussion_id: 讨论ID(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
消息列表
|
||||||
|
"""
|
||||||
|
query = {"room_id": room_id}
|
||||||
|
if discussion_id:
|
||||||
|
query["discussion_id"] = discussion_id
|
||||||
|
|
||||||
|
return await Message.find(query).sort(
|
||||||
|
"-created_at"
|
||||||
|
).skip(skip).limit(limit).to_list()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_recent_messages(
|
||||||
|
cls,
|
||||||
|
room_id: str,
|
||||||
|
count: int = 20,
|
||||||
|
discussion_id: Optional[str] = None
|
||||||
|
) -> List[Message]:
|
||||||
|
"""
|
||||||
|
获取最近的消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
count: 消息数量
|
||||||
|
discussion_id: 讨论ID(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
消息列表(按时间正序)
|
||||||
|
"""
|
||||||
|
messages = await cls.get_messages(
|
||||||
|
room_id,
|
||||||
|
limit=count,
|
||||||
|
discussion_id=discussion_id
|
||||||
|
)
|
||||||
|
return list(reversed(messages)) # 返回正序
|
||||||
227
backend/services/consensus_manager.py
Normal file
227
backend/services/consensus_manager.py
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
"""
|
||||||
|
共识管理器
|
||||||
|
判断讨论是否达成共识
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from models.agent import Agent
|
||||||
|
from models.chatroom import ChatRoom
|
||||||
|
from services.ai_provider_service import AIProviderService
|
||||||
|
|
||||||
|
|
||||||
|
class ConsensusManager:
|
||||||
|
"""
|
||||||
|
共识管理器
|
||||||
|
使用主持人Agent判断讨论共识
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 共识判断提示词模板
|
||||||
|
CONSENSUS_PROMPT = """你是讨论的主持人,负责判断讨论是否达成共识。
|
||||||
|
|
||||||
|
讨论目标:{objective}
|
||||||
|
|
||||||
|
对话历史:
|
||||||
|
{history}
|
||||||
|
|
||||||
|
请仔细分析对话内容,判断:
|
||||||
|
1. 参与者是否对核心问题达成一致意见?
|
||||||
|
2. 是否还有重要分歧未解决?
|
||||||
|
3. 讨论结果是否足够明确和可执行?
|
||||||
|
|
||||||
|
请以JSON格式回复(不要包含任何其他文字):
|
||||||
|
{{
|
||||||
|
"consensus_reached": true或false,
|
||||||
|
"confidence": 0到1之间的数字,
|
||||||
|
"summary": "讨论结果摘要,简洁概括达成的共识或当前状态",
|
||||||
|
"action_items": ["具体的行动项列表"],
|
||||||
|
"unresolved_issues": ["未解决的问题列表"],
|
||||||
|
"key_decisions": ["关键决策列表"]
|
||||||
|
}}
|
||||||
|
|
||||||
|
注意:
|
||||||
|
- consensus_reached为true表示核心问题已有明确结论
|
||||||
|
- confidence表示你对共识判断的信心程度
|
||||||
|
- 如果讨论仍有争议或不够深入,应该返回false
|
||||||
|
- action_items应该是具体可执行的任务
|
||||||
|
- 请确保返回有效的JSON格式"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def check_consensus(
|
||||||
|
cls,
|
||||||
|
moderator: Agent,
|
||||||
|
context: "DiscussionContext",
|
||||||
|
chatroom: ChatRoom
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
检查是否达成共识
|
||||||
|
|
||||||
|
Args:
|
||||||
|
moderator: 主持人Agent
|
||||||
|
context: 讨论上下文
|
||||||
|
chatroom: 聊天室
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
共识判断结果
|
||||||
|
"""
|
||||||
|
from services.discussion_engine import DiscussionContext
|
||||||
|
|
||||||
|
# 构建历史记录
|
||||||
|
history_text = ""
|
||||||
|
for msg in context.messages:
|
||||||
|
if msg.agent_id:
|
||||||
|
history_text += f"[{msg.agent_id}]: {msg.content}\n\n"
|
||||||
|
else:
|
||||||
|
history_text += f"[系统]: {msg.content}\n\n"
|
||||||
|
|
||||||
|
if not history_text:
|
||||||
|
return {
|
||||||
|
"consensus_reached": False,
|
||||||
|
"confidence": 0,
|
||||||
|
"summary": "讨论尚未开始",
|
||||||
|
"action_items": [],
|
||||||
|
"unresolved_issues": [],
|
||||||
|
"key_decisions": []
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建提示词
|
||||||
|
prompt = cls.CONSENSUS_PROMPT.format(
|
||||||
|
objective=context.objective,
|
||||||
|
history=history_text
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 调用主持人Agent的AI接口
|
||||||
|
response = await AIProviderService.chat(
|
||||||
|
provider_id=moderator.provider_id,
|
||||||
|
messages=[{"role": "user", "content": prompt}],
|
||||||
|
temperature=0.3, # 使用较低温度以获得更一致的结果
|
||||||
|
max_tokens=1000
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response.success:
|
||||||
|
logger.error(f"共识判断失败: {response.error}")
|
||||||
|
return cls._default_result("AI接口调用失败")
|
||||||
|
|
||||||
|
# 解析JSON响应
|
||||||
|
content = response.content.strip()
|
||||||
|
|
||||||
|
# 尝试提取JSON部分
|
||||||
|
try:
|
||||||
|
# 尝试直接解析
|
||||||
|
result = json.loads(content)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# 尝试提取JSON块
|
||||||
|
import re
|
||||||
|
json_match = re.search(r'\{[\s\S]*\}', content)
|
||||||
|
if json_match:
|
||||||
|
try:
|
||||||
|
result = json.loads(json_match.group())
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning(f"无法解析共识判断结果: {content}")
|
||||||
|
return cls._default_result("无法解析AI响应")
|
||||||
|
else:
|
||||||
|
return cls._default_result("AI响应格式错误")
|
||||||
|
|
||||||
|
# 验证和规范化结果
|
||||||
|
return cls._normalize_result(result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"共识判断异常: {e}")
|
||||||
|
return cls._default_result(str(e))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def generate_summary(
|
||||||
|
cls,
|
||||||
|
moderator: Agent,
|
||||||
|
context: "DiscussionContext"
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
生成讨论摘要
|
||||||
|
|
||||||
|
Args:
|
||||||
|
moderator: 主持人Agent
|
||||||
|
context: 讨论上下文
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
讨论摘要
|
||||||
|
"""
|
||||||
|
from services.discussion_engine import DiscussionContext
|
||||||
|
|
||||||
|
# 构建历史记录
|
||||||
|
history_text = ""
|
||||||
|
for msg in context.messages:
|
||||||
|
if msg.agent_id:
|
||||||
|
history_text += f"[{msg.agent_id}]: {msg.content}\n\n"
|
||||||
|
|
||||||
|
prompt = f"""请为以下讨论生成一份简洁的摘要。
|
||||||
|
|
||||||
|
讨论目标:{context.objective}
|
||||||
|
|
||||||
|
对话记录:
|
||||||
|
{history_text}
|
||||||
|
|
||||||
|
请提供:
|
||||||
|
1. 讨论的主要观点和结论
|
||||||
|
2. 参与者的立场和建议
|
||||||
|
3. 最终的决策或共识(如果有)
|
||||||
|
|
||||||
|
摘要应该简洁明了,控制在300字以内。"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await AIProviderService.chat(
|
||||||
|
provider_id=moderator.provider_id,
|
||||||
|
messages=[{"role": "user", "content": prompt}],
|
||||||
|
temperature=0.5,
|
||||||
|
max_tokens=500
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.success:
|
||||||
|
return response.content.strip()
|
||||||
|
else:
|
||||||
|
return "无法生成摘要"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"生成摘要异常: {e}")
|
||||||
|
return "生成摘要时发生错误"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _default_result(cls, error: str = "") -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
返回默认结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error: 错误信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
默认共识结果
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"consensus_reached": False,
|
||||||
|
"confidence": 0,
|
||||||
|
"summary": error if error else "共识判断失败",
|
||||||
|
"action_items": [],
|
||||||
|
"unresolved_issues": [],
|
||||||
|
"key_decisions": []
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _normalize_result(cls, result: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
规范化共识结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: 原始结果
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
规范化的结果
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"consensus_reached": bool(result.get("consensus_reached", False)),
|
||||||
|
"confidence": max(0, min(1, float(result.get("confidence", 0)))),
|
||||||
|
"summary": str(result.get("summary", "")),
|
||||||
|
"action_items": list(result.get("action_items", [])),
|
||||||
|
"unresolved_issues": list(result.get("unresolved_issues", [])),
|
||||||
|
"key_decisions": list(result.get("key_decisions", []))
|
||||||
|
}
|
||||||
589
backend/services/discussion_engine.py
Normal file
589
backend/services/discussion_engine.py
Normal file
@@ -0,0 +1,589 @@
|
|||||||
|
"""
|
||||||
|
讨论引擎
|
||||||
|
实现自由讨论的核心逻辑
|
||||||
|
"""
|
||||||
|
import uuid
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from models.chatroom import ChatRoom, ChatRoomStatus
|
||||||
|
from models.agent import Agent
|
||||||
|
from models.message import Message, MessageType
|
||||||
|
from models.discussion_result import DiscussionResult
|
||||||
|
from services.ai_provider_service import AIProviderService
|
||||||
|
from services.agent_service import AgentService
|
||||||
|
from services.chatroom_service import ChatRoomService
|
||||||
|
from services.message_router import MessageRouter
|
||||||
|
from services.consensus_manager import ConsensusManager
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DiscussionContext:
|
||||||
|
"""讨论上下文"""
|
||||||
|
discussion_id: str
|
||||||
|
room_id: str
|
||||||
|
objective: str
|
||||||
|
current_round: int = 0
|
||||||
|
messages: List[Message] = field(default_factory=list)
|
||||||
|
agent_speak_counts: Dict[str, int] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def add_message(self, message: Message) -> None:
|
||||||
|
"""添加消息到上下文"""
|
||||||
|
self.messages.append(message)
|
||||||
|
if message.agent_id:
|
||||||
|
self.agent_speak_counts[message.agent_id] = \
|
||||||
|
self.agent_speak_counts.get(message.agent_id, 0) + 1
|
||||||
|
|
||||||
|
def get_recent_messages(self, count: int = 20) -> List[Message]:
|
||||||
|
"""获取最近的消息"""
|
||||||
|
return self.messages[-count:] if len(self.messages) > count else self.messages
|
||||||
|
|
||||||
|
def get_agent_speak_count(self, agent_id: str) -> int:
|
||||||
|
"""获取Agent在当前轮次的发言次数"""
|
||||||
|
return self.agent_speak_counts.get(agent_id, 0)
|
||||||
|
|
||||||
|
def reset_round_counts(self) -> None:
|
||||||
|
"""重置轮次发言计数"""
|
||||||
|
self.agent_speak_counts.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class DiscussionEngine:
|
||||||
|
"""
|
||||||
|
讨论引擎
|
||||||
|
实现多Agent自由讨论的核心逻辑
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 活跃的讨论: room_id -> DiscussionContext
|
||||||
|
_active_discussions: Dict[str, DiscussionContext] = {}
|
||||||
|
|
||||||
|
# 停止信号
|
||||||
|
_stop_signals: Dict[str, bool] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def start_discussion(
|
||||||
|
cls,
|
||||||
|
room_id: str,
|
||||||
|
objective: str
|
||||||
|
) -> Optional[DiscussionResult]:
|
||||||
|
"""
|
||||||
|
启动讨论
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
objective: 讨论目标
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
讨论结果
|
||||||
|
"""
|
||||||
|
# 获取聊天室
|
||||||
|
chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||||
|
if not chatroom:
|
||||||
|
raise ValueError(f"聊天室不存在: {room_id}")
|
||||||
|
|
||||||
|
if not chatroom.agents:
|
||||||
|
raise ValueError("聊天室没有Agent参与")
|
||||||
|
|
||||||
|
if not objective:
|
||||||
|
raise ValueError("讨论目标不能为空")
|
||||||
|
|
||||||
|
# 检查是否已有活跃讨论
|
||||||
|
if room_id in cls._active_discussions:
|
||||||
|
raise ValueError("聊天室已有进行中的讨论")
|
||||||
|
|
||||||
|
# 创建讨论
|
||||||
|
discussion_id = f"disc-{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
# 创建讨论结果记录
|
||||||
|
discussion_result = DiscussionResult(
|
||||||
|
discussion_id=discussion_id,
|
||||||
|
room_id=room_id,
|
||||||
|
objective=objective,
|
||||||
|
status="in_progress",
|
||||||
|
created_at=datetime.utcnow(),
|
||||||
|
updated_at=datetime.utcnow()
|
||||||
|
)
|
||||||
|
await discussion_result.insert()
|
||||||
|
|
||||||
|
# 创建讨论上下文
|
||||||
|
context = DiscussionContext(
|
||||||
|
discussion_id=discussion_id,
|
||||||
|
room_id=room_id,
|
||||||
|
objective=objective
|
||||||
|
)
|
||||||
|
cls._active_discussions[room_id] = context
|
||||||
|
cls._stop_signals[room_id] = False
|
||||||
|
|
||||||
|
# 更新聊天室状态
|
||||||
|
await ChatRoomService.update_chatroom(
|
||||||
|
room_id,
|
||||||
|
status=ChatRoomStatus.ACTIVE.value,
|
||||||
|
objective=objective,
|
||||||
|
current_discussion_id=discussion_id,
|
||||||
|
current_round=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# 广播讨论开始
|
||||||
|
await MessageRouter.broadcast_status(room_id, "discussion_started", {
|
||||||
|
"discussion_id": discussion_id,
|
||||||
|
"objective": objective
|
||||||
|
})
|
||||||
|
|
||||||
|
# 发送系统消息
|
||||||
|
await MessageRouter.save_and_broadcast_message(
|
||||||
|
room_id=room_id,
|
||||||
|
discussion_id=discussion_id,
|
||||||
|
agent_id=None,
|
||||||
|
content=f"讨论开始\n\n目标:{objective}",
|
||||||
|
message_type=MessageType.SYSTEM.value,
|
||||||
|
round_num=0
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"讨论开始: {room_id} - {objective}")
|
||||||
|
|
||||||
|
# 运行讨论循环
|
||||||
|
try:
|
||||||
|
result = await cls._run_discussion_loop(chatroom, context)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"讨论异常: {e}")
|
||||||
|
await cls._handle_discussion_error(room_id, discussion_id, str(e))
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# 清理
|
||||||
|
cls._active_discussions.pop(room_id, None)
|
||||||
|
cls._stop_signals.pop(room_id, None)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def stop_discussion(cls, room_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
停止讨论
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功
|
||||||
|
"""
|
||||||
|
if room_id not in cls._active_discussions:
|
||||||
|
return False
|
||||||
|
|
||||||
|
cls._stop_signals[room_id] = True
|
||||||
|
logger.info(f"收到停止讨论信号: {room_id}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def pause_discussion(cls, room_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
暂停讨论
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功
|
||||||
|
"""
|
||||||
|
if room_id not in cls._active_discussions:
|
||||||
|
return False
|
||||||
|
|
||||||
|
await ChatRoomService.update_status(room_id, ChatRoomStatus.PAUSED)
|
||||||
|
await MessageRouter.broadcast_status(room_id, "discussion_paused")
|
||||||
|
|
||||||
|
logger.info(f"讨论暂停: {room_id}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def resume_discussion(cls, room_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
恢复讨论
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功
|
||||||
|
"""
|
||||||
|
chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||||
|
if not chatroom or chatroom.status != ChatRoomStatus.PAUSED.value:
|
||||||
|
return False
|
||||||
|
|
||||||
|
await ChatRoomService.update_status(room_id, ChatRoomStatus.ACTIVE)
|
||||||
|
await MessageRouter.broadcast_status(room_id, "discussion_resumed")
|
||||||
|
|
||||||
|
logger.info(f"讨论恢复: {room_id}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _run_discussion_loop(
|
||||||
|
cls,
|
||||||
|
chatroom: ChatRoom,
|
||||||
|
context: DiscussionContext
|
||||||
|
) -> DiscussionResult:
|
||||||
|
"""
|
||||||
|
运行讨论循环
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chatroom: 聊天室
|
||||||
|
context: 讨论上下文
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
讨论结果
|
||||||
|
"""
|
||||||
|
room_id = chatroom.room_id
|
||||||
|
config = chatroom.get_config()
|
||||||
|
|
||||||
|
# 获取所有Agent
|
||||||
|
agents = await AgentService.get_agents_by_ids(chatroom.agents)
|
||||||
|
agent_map = {a.agent_id: a for a in agents}
|
||||||
|
|
||||||
|
# 获取主持人(用于共识判断)
|
||||||
|
moderator = None
|
||||||
|
if chatroom.moderator_agent_id:
|
||||||
|
moderator = await AgentService.get_agent(chatroom.moderator_agent_id)
|
||||||
|
|
||||||
|
consecutive_no_speak = 0 # 连续无人发言的轮次
|
||||||
|
|
||||||
|
while context.current_round < config.max_rounds:
|
||||||
|
# 检查停止信号
|
||||||
|
if cls._stop_signals.get(room_id, False):
|
||||||
|
break
|
||||||
|
|
||||||
|
# 检查暂停状态
|
||||||
|
current_chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||||
|
if current_chatroom and current_chatroom.status == ChatRoomStatus.PAUSED.value:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 增加轮次
|
||||||
|
context.current_round += 1
|
||||||
|
context.reset_round_counts()
|
||||||
|
|
||||||
|
# 广播轮次信息
|
||||||
|
await MessageRouter.broadcast_round_info(
|
||||||
|
room_id,
|
||||||
|
context.current_round,
|
||||||
|
config.max_rounds
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新聊天室轮次
|
||||||
|
await ChatRoomService.update_chatroom(
|
||||||
|
room_id,
|
||||||
|
current_round=context.current_round
|
||||||
|
)
|
||||||
|
|
||||||
|
# 本轮是否有人发言
|
||||||
|
round_has_message = False
|
||||||
|
|
||||||
|
# 遍历所有Agent,判断是否发言
|
||||||
|
for agent_id in chatroom.agents:
|
||||||
|
agent = agent_map.get(agent_id)
|
||||||
|
if not agent or not agent.enabled:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查本轮发言次数限制
|
||||||
|
behavior = agent.get_behavior()
|
||||||
|
if context.get_agent_speak_count(agent_id) >= behavior.max_speak_per_round:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 判断是否发言
|
||||||
|
should_speak, content = await cls._should_agent_speak(
|
||||||
|
agent, context, chatroom
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_speak and content:
|
||||||
|
# 广播输入状态
|
||||||
|
await MessageRouter.broadcast_typing(room_id, agent_id, True)
|
||||||
|
|
||||||
|
# 保存并广播消息
|
||||||
|
message = await MessageRouter.save_and_broadcast_message(
|
||||||
|
room_id=room_id,
|
||||||
|
discussion_id=context.discussion_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
content=content,
|
||||||
|
message_type=MessageType.TEXT.value,
|
||||||
|
round_num=context.current_round
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新上下文
|
||||||
|
context.add_message(message)
|
||||||
|
round_has_message = True
|
||||||
|
|
||||||
|
# 广播输入结束
|
||||||
|
await MessageRouter.broadcast_typing(room_id, agent_id, False)
|
||||||
|
|
||||||
|
# 轮次间隔
|
||||||
|
await asyncio.sleep(config.round_interval)
|
||||||
|
|
||||||
|
# 检查是否需要共识判断
|
||||||
|
if round_has_message and moderator:
|
||||||
|
consecutive_no_speak = 0
|
||||||
|
|
||||||
|
# 每隔几轮检查一次共识
|
||||||
|
if context.current_round % 3 == 0 or context.current_round >= config.max_rounds - 5:
|
||||||
|
consensus_result = await ConsensusManager.check_consensus(
|
||||||
|
moderator, context, chatroom
|
||||||
|
)
|
||||||
|
|
||||||
|
if consensus_result.get("consensus_reached", False):
|
||||||
|
confidence = consensus_result.get("confidence", 0)
|
||||||
|
if confidence >= config.consensus_threshold:
|
||||||
|
# 达成共识,结束讨论
|
||||||
|
return await cls._finalize_discussion(
|
||||||
|
context,
|
||||||
|
consensus_result,
|
||||||
|
"consensus"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
consecutive_no_speak += 1
|
||||||
|
|
||||||
|
# 连续多轮无人发言,检查共识或结束
|
||||||
|
if consecutive_no_speak >= 3:
|
||||||
|
if moderator:
|
||||||
|
consensus_result = await ConsensusManager.check_consensus(
|
||||||
|
moderator, context, chatroom
|
||||||
|
)
|
||||||
|
return await cls._finalize_discussion(
|
||||||
|
context,
|
||||||
|
consensus_result,
|
||||||
|
"no_more_discussion"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return await cls._finalize_discussion(
|
||||||
|
context,
|
||||||
|
{"consensus_reached": False, "summary": "讨论结束,无明确共识"},
|
||||||
|
"no_more_discussion"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 达到最大轮次
|
||||||
|
if moderator:
|
||||||
|
consensus_result = await ConsensusManager.check_consensus(
|
||||||
|
moderator, context, chatroom
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
consensus_result = {"consensus_reached": False, "summary": "达到最大轮次限制"}
|
||||||
|
|
||||||
|
return await cls._finalize_discussion(
|
||||||
|
context,
|
||||||
|
consensus_result,
|
||||||
|
"max_rounds"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _should_agent_speak(
|
||||||
|
cls,
|
||||||
|
agent: Agent,
|
||||||
|
context: DiscussionContext,
|
||||||
|
chatroom: ChatRoom
|
||||||
|
) -> tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
判断Agent是否应该发言
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: Agent实例
|
||||||
|
context: 讨论上下文
|
||||||
|
chatroom: 聊天室
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(是否发言, 发言内容)
|
||||||
|
"""
|
||||||
|
# 构建判断提示词
|
||||||
|
recent_messages = context.get_recent_messages(chatroom.get_config().message_history_size)
|
||||||
|
|
||||||
|
history_text = ""
|
||||||
|
for msg in recent_messages:
|
||||||
|
if msg.agent_id:
|
||||||
|
history_text += f"[{msg.agent_id}]: {msg.content}\n\n"
|
||||||
|
else:
|
||||||
|
history_text += f"[系统]: {msg.content}\n\n"
|
||||||
|
|
||||||
|
prompt = f"""你是{agent.name},角色是{agent.role}。
|
||||||
|
|
||||||
|
{agent.system_prompt}
|
||||||
|
|
||||||
|
当前讨论目标:{context.objective}
|
||||||
|
|
||||||
|
对话历史:
|
||||||
|
{history_text if history_text else "(还没有对话)"}
|
||||||
|
|
||||||
|
当前是第{context.current_round}轮讨论。
|
||||||
|
|
||||||
|
请根据你的角色判断:
|
||||||
|
1. 你是否有新的观点或建议要分享?
|
||||||
|
2. 你是否需要回应其他人的观点?
|
||||||
|
3. 当前讨论是否需要你的专业意见?
|
||||||
|
|
||||||
|
如果你认为需要发言,请直接给出你的发言内容。
|
||||||
|
如果你认为暂时不需要发言(例如等待更多信息、当前轮次已有足够讨论、或者你的观点已经充分表达),请只回复"PASS"。
|
||||||
|
|
||||||
|
注意:
|
||||||
|
- 请保持发言简洁有力,每次发言控制在200字以内
|
||||||
|
- 避免重复已经说过的内容
|
||||||
|
- 如果已经达成共识或接近共识,可以选择PASS"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 调用AI接口
|
||||||
|
response = await AIProviderService.chat(
|
||||||
|
provider_id=agent.provider_id,
|
||||||
|
messages=[{"role": "user", "content": prompt}],
|
||||||
|
temperature=agent.temperature,
|
||||||
|
max_tokens=agent.max_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response.success:
|
||||||
|
logger.warning(f"Agent {agent.agent_id} 响应失败: {response.error}")
|
||||||
|
return False, ""
|
||||||
|
|
||||||
|
content = response.content.strip()
|
||||||
|
|
||||||
|
# 判断是否PASS
|
||||||
|
if content.upper() == "PASS" or content.upper().startswith("PASS"):
|
||||||
|
return False, ""
|
||||||
|
|
||||||
|
return True, content
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Agent {agent.agent_id} 判断发言异常: {e}")
|
||||||
|
return False, ""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _finalize_discussion(
|
||||||
|
cls,
|
||||||
|
context: DiscussionContext,
|
||||||
|
consensus_result: Dict[str, Any],
|
||||||
|
end_reason: str
|
||||||
|
) -> DiscussionResult:
|
||||||
|
"""
|
||||||
|
完成讨论,保存结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: 讨论上下文
|
||||||
|
consensus_result: 共识判断结果
|
||||||
|
end_reason: 结束原因
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
讨论结果
|
||||||
|
"""
|
||||||
|
room_id = context.room_id
|
||||||
|
|
||||||
|
# 获取讨论结果记录
|
||||||
|
discussion_result = await DiscussionResult.find_one(
|
||||||
|
DiscussionResult.discussion_id == context.discussion_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if discussion_result:
|
||||||
|
# 更新统计
|
||||||
|
discussion_result.update_stats(
|
||||||
|
total_rounds=context.current_round,
|
||||||
|
total_messages=len(context.messages),
|
||||||
|
agent_contributions=context.agent_speak_counts
|
||||||
|
)
|
||||||
|
|
||||||
|
# 标记完成
|
||||||
|
discussion_result.mark_completed(
|
||||||
|
consensus_reached=consensus_result.get("consensus_reached", False),
|
||||||
|
confidence=consensus_result.get("confidence", 0),
|
||||||
|
summary=consensus_result.get("summary", ""),
|
||||||
|
action_items=consensus_result.get("action_items", []),
|
||||||
|
unresolved_issues=consensus_result.get("unresolved_issues", []),
|
||||||
|
end_reason=end_reason
|
||||||
|
)
|
||||||
|
|
||||||
|
await discussion_result.save()
|
||||||
|
|
||||||
|
# 更新聊天室状态
|
||||||
|
await ChatRoomService.update_status(room_id, ChatRoomStatus.COMPLETED)
|
||||||
|
|
||||||
|
# 发送系统消息
|
||||||
|
summary_text = f"""讨论结束
|
||||||
|
|
||||||
|
结果:{"达成共识" if consensus_result.get("consensus_reached") else "未达成明确共识"}
|
||||||
|
置信度:{consensus_result.get("confidence", 0):.0%}
|
||||||
|
|
||||||
|
摘要:{consensus_result.get("summary", "无")}
|
||||||
|
|
||||||
|
行动项:
|
||||||
|
{chr(10).join("- " + item for item in consensus_result.get("action_items", [])) or "无"}
|
||||||
|
|
||||||
|
未解决问题:
|
||||||
|
{chr(10).join("- " + issue for issue in consensus_result.get("unresolved_issues", [])) or "无"}
|
||||||
|
|
||||||
|
共进行 {context.current_round} 轮讨论,产生 {len(context.messages)} 条消息。"""
|
||||||
|
|
||||||
|
await MessageRouter.save_and_broadcast_message(
|
||||||
|
room_id=room_id,
|
||||||
|
discussion_id=context.discussion_id,
|
||||||
|
agent_id=None,
|
||||||
|
content=summary_text,
|
||||||
|
message_type=MessageType.SYSTEM.value,
|
||||||
|
round_num=context.current_round
|
||||||
|
)
|
||||||
|
|
||||||
|
# 广播讨论结束
|
||||||
|
await MessageRouter.broadcast_status(room_id, "discussion_completed", {
|
||||||
|
"discussion_id": context.discussion_id,
|
||||||
|
"consensus_reached": consensus_result.get("consensus_reached", False),
|
||||||
|
"end_reason": end_reason
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.info(f"讨论结束: {room_id}, 原因: {end_reason}")
|
||||||
|
|
||||||
|
return discussion_result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _handle_discussion_error(
|
||||||
|
cls,
|
||||||
|
room_id: str,
|
||||||
|
discussion_id: str,
|
||||||
|
error: str
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
处理讨论错误
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
discussion_id: 讨论ID
|
||||||
|
error: 错误信息
|
||||||
|
"""
|
||||||
|
# 更新聊天室状态
|
||||||
|
await ChatRoomService.update_status(room_id, ChatRoomStatus.ERROR)
|
||||||
|
|
||||||
|
# 更新讨论结果
|
||||||
|
discussion_result = await DiscussionResult.find_one(
|
||||||
|
DiscussionResult.discussion_id == discussion_id
|
||||||
|
)
|
||||||
|
if discussion_result:
|
||||||
|
discussion_result.status = "failed"
|
||||||
|
discussion_result.end_reason = f"error: {error}"
|
||||||
|
discussion_result.updated_at = datetime.utcnow()
|
||||||
|
await discussion_result.save()
|
||||||
|
|
||||||
|
# 广播错误
|
||||||
|
await MessageRouter.broadcast_error(room_id, error)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_active_discussion(cls, room_id: str) -> Optional[DiscussionContext]:
|
||||||
|
"""
|
||||||
|
获取活跃的讨论上下文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
讨论上下文或None
|
||||||
|
"""
|
||||||
|
return cls._active_discussions.get(room_id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_discussion_active(cls, room_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查是否有活跃讨论
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否活跃
|
||||||
|
"""
|
||||||
|
return room_id in cls._active_discussions
|
||||||
252
backend/services/mcp_service.py
Normal file
252
backend/services/mcp_service.py
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
"""
|
||||||
|
MCP服务
|
||||||
|
管理MCP工具的集成和调用
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
from pathlib import Path
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
class MCPService:
|
||||||
|
"""
|
||||||
|
MCP工具服务
|
||||||
|
集成MCP服务器,提供工具调用能力
|
||||||
|
"""
|
||||||
|
|
||||||
|
# MCP服务器配置目录
|
||||||
|
MCP_CONFIG_DIR = Path(os.getenv("CURSOR_MCP_DIR", "~/.cursor/mcps")).expanduser()
|
||||||
|
|
||||||
|
# 已注册的工具: server_name -> List[tool_info]
|
||||||
|
_registered_tools: Dict[str, List[Dict[str, Any]]] = {}
|
||||||
|
|
||||||
|
# Agent工具映射: agent_id -> List[tool_name]
|
||||||
|
_agent_tools: Dict[str, List[str]] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def initialize(cls) -> None:
|
||||||
|
"""
|
||||||
|
初始化MCP服务
|
||||||
|
扫描并注册可用的MCP工具
|
||||||
|
"""
|
||||||
|
logger.info("初始化MCP服务...")
|
||||||
|
|
||||||
|
if not cls.MCP_CONFIG_DIR.exists():
|
||||||
|
logger.warning(f"MCP配置目录不存在: {cls.MCP_CONFIG_DIR}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 扫描MCP服务器目录
|
||||||
|
for server_dir in cls.MCP_CONFIG_DIR.iterdir():
|
||||||
|
if server_dir.is_dir():
|
||||||
|
await cls._scan_server(server_dir)
|
||||||
|
|
||||||
|
logger.info(f"MCP服务初始化完成,已注册 {len(cls._registered_tools)} 个服务器")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _scan_server(cls, server_dir: Path) -> None:
|
||||||
|
"""
|
||||||
|
扫描MCP服务器目录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_dir: 服务器目录
|
||||||
|
"""
|
||||||
|
server_name = server_dir.name
|
||||||
|
tools_dir = server_dir / "tools"
|
||||||
|
|
||||||
|
if not tools_dir.exists():
|
||||||
|
return
|
||||||
|
|
||||||
|
tools = []
|
||||||
|
for tool_file in tools_dir.glob("*.json"):
|
||||||
|
try:
|
||||||
|
with open(tool_file, "r", encoding="utf-8") as f:
|
||||||
|
tool_info = json.load(f)
|
||||||
|
tool_info["_file"] = str(tool_file)
|
||||||
|
tools.append(tool_info)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"加载MCP工具配置失败: {tool_file} - {e}")
|
||||||
|
|
||||||
|
if tools:
|
||||||
|
cls._registered_tools[server_name] = tools
|
||||||
|
logger.debug(f"注册MCP服务器: {server_name}, 工具数: {len(tools)}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_servers(cls) -> List[str]:
|
||||||
|
"""
|
||||||
|
列出所有可用的MCP服务器
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
服务器名称列表
|
||||||
|
"""
|
||||||
|
return list(cls._registered_tools.keys())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_tools(cls, server: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
列出可用的MCP工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server: 服务器名称(可选,不指定则返回所有)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具信息列表
|
||||||
|
"""
|
||||||
|
if server:
|
||||||
|
return cls._registered_tools.get(server, [])
|
||||||
|
|
||||||
|
# 返回所有工具
|
||||||
|
all_tools = []
|
||||||
|
for server_name, tools in cls._registered_tools.items():
|
||||||
|
for tool in tools:
|
||||||
|
tool_copy = tool.copy()
|
||||||
|
tool_copy["server"] = server_name
|
||||||
|
all_tools.append(tool_copy)
|
||||||
|
|
||||||
|
return all_tools
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_tool(cls, server: str, tool_name: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
获取指定工具的信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server: 服务器名称
|
||||||
|
tool_name: 工具名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具信息或None
|
||||||
|
"""
|
||||||
|
tools = cls._registered_tools.get(server, [])
|
||||||
|
for tool in tools:
|
||||||
|
if tool.get("name") == tool_name:
|
||||||
|
return tool
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def call_tool(
|
||||||
|
cls,
|
||||||
|
server: str,
|
||||||
|
tool_name: str,
|
||||||
|
arguments: Dict[str, Any]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
调用MCP工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server: 服务器名称
|
||||||
|
tool_name: 工具名称
|
||||||
|
arguments: 工具参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
调用结果
|
||||||
|
"""
|
||||||
|
tool = cls.get_tool(server, tool_name)
|
||||||
|
if not tool:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"工具不存在: {server}/{tool_name}"
|
||||||
|
}
|
||||||
|
|
||||||
|
# TODO: 实际的MCP工具调用逻辑
|
||||||
|
# 这里需要根据MCP协议实现工具调用
|
||||||
|
# 目前返回模拟结果
|
||||||
|
logger.info(f"调用MCP工具: {server}/{tool_name}, 参数: {arguments}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"result": f"MCP工具调用: {tool_name}",
|
||||||
|
"tool": tool_name,
|
||||||
|
"server": server,
|
||||||
|
"arguments": arguments
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_tool_for_agent(
|
||||||
|
cls,
|
||||||
|
agent_id: str,
|
||||||
|
tool_name: str
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
为Agent注册可用工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent ID
|
||||||
|
tool_name: 工具名称(格式: server/tool_name)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否注册成功
|
||||||
|
"""
|
||||||
|
if agent_id not in cls._agent_tools:
|
||||||
|
cls._agent_tools[agent_id] = []
|
||||||
|
|
||||||
|
if tool_name not in cls._agent_tools[agent_id]:
|
||||||
|
cls._agent_tools[agent_id].append(tool_name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def unregister_tool_for_agent(
|
||||||
|
cls,
|
||||||
|
agent_id: str,
|
||||||
|
tool_name: str
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
为Agent注销工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent ID
|
||||||
|
tool_name: 工具名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否注销成功
|
||||||
|
"""
|
||||||
|
if agent_id in cls._agent_tools:
|
||||||
|
if tool_name in cls._agent_tools[agent_id]:
|
||||||
|
cls._agent_tools[agent_id].remove(tool_name)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_agent_tools(cls, agent_id: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
获取Agent可用的工具列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具名称列表
|
||||||
|
"""
|
||||||
|
return cls._agent_tools.get(agent_id, [])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_tools_for_prompt(cls, agent_id: str) -> str:
|
||||||
|
"""
|
||||||
|
获取用于提示词的工具描述
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具描述文本
|
||||||
|
"""
|
||||||
|
tool_names = cls.get_agent_tools(agent_id)
|
||||||
|
if not tool_names:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
descriptions = []
|
||||||
|
for full_name in tool_names:
|
||||||
|
parts = full_name.split("/", 1)
|
||||||
|
if len(parts) == 2:
|
||||||
|
server, tool_name = parts
|
||||||
|
tool = cls.get_tool(server, tool_name)
|
||||||
|
if tool:
|
||||||
|
desc = tool.get("description", "无描述")
|
||||||
|
descriptions.append(f"- {tool_name}: {desc}")
|
||||||
|
|
||||||
|
if not descriptions:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
return "你可以使用以下工具:\n" + "\n".join(descriptions)
|
||||||
416
backend/services/memory_service.py
Normal file
416
backend/services/memory_service.py
Normal file
@@ -0,0 +1,416 @@
|
|||||||
|
"""
|
||||||
|
记忆服务
|
||||||
|
管理Agent的记忆存储和检索
|
||||||
|
"""
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
import numpy as np
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from models.agent_memory import AgentMemory, MemoryType
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryService:
|
||||||
|
"""
|
||||||
|
Agent记忆服务
|
||||||
|
提供记忆的存储、检索和管理功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 嵌入模型(延迟加载)
|
||||||
|
_embedding_model = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_embedding_model(cls):
|
||||||
|
"""
|
||||||
|
获取嵌入模型实例(延迟加载)
|
||||||
|
"""
|
||||||
|
if cls._embedding_model is None:
|
||||||
|
try:
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
cls._embedding_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
|
||||||
|
logger.info("嵌入模型加载成功")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"嵌入模型加载失败: {e}")
|
||||||
|
return None
|
||||||
|
return cls._embedding_model
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_memory(
|
||||||
|
cls,
|
||||||
|
agent_id: str,
|
||||||
|
content: str,
|
||||||
|
memory_type: str = MemoryType.SHORT_TERM.value,
|
||||||
|
importance: float = 0.5,
|
||||||
|
source_room_id: Optional[str] = None,
|
||||||
|
source_discussion_id: Optional[str] = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
|
expires_in_hours: Optional[int] = None
|
||||||
|
) -> AgentMemory:
|
||||||
|
"""
|
||||||
|
创建新的记忆
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent ID
|
||||||
|
content: 记忆内容
|
||||||
|
memory_type: 记忆类型
|
||||||
|
importance: 重要性评分
|
||||||
|
source_room_id: 来源聊天室
|
||||||
|
source_discussion_id: 来源讨论
|
||||||
|
tags: 标签
|
||||||
|
expires_in_hours: 过期时间(小时)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
创建的AgentMemory文档
|
||||||
|
"""
|
||||||
|
memory_id = f"mem-{uuid.uuid4().hex[:12]}"
|
||||||
|
|
||||||
|
# 生成向量嵌入
|
||||||
|
embedding = await cls._generate_embedding(content)
|
||||||
|
|
||||||
|
# 生成摘要
|
||||||
|
summary = content[:100] + "..." if len(content) > 100 else content
|
||||||
|
|
||||||
|
# 计算过期时间
|
||||||
|
expires_at = None
|
||||||
|
if expires_in_hours:
|
||||||
|
expires_at = datetime.utcnow() + timedelta(hours=expires_in_hours)
|
||||||
|
|
||||||
|
memory = AgentMemory(
|
||||||
|
memory_id=memory_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
memory_type=memory_type,
|
||||||
|
content=content,
|
||||||
|
summary=summary,
|
||||||
|
embedding=embedding,
|
||||||
|
importance=importance,
|
||||||
|
source_room_id=source_room_id,
|
||||||
|
source_discussion_id=source_discussion_id,
|
||||||
|
tags=tags or [],
|
||||||
|
created_at=datetime.utcnow(),
|
||||||
|
last_accessed=datetime.utcnow(),
|
||||||
|
expires_at=expires_at
|
||||||
|
)
|
||||||
|
|
||||||
|
await memory.insert()
|
||||||
|
|
||||||
|
logger.debug(f"创建记忆: {memory_id} for Agent {agent_id}")
|
||||||
|
return memory
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_memory(cls, memory_id: str) -> Optional[AgentMemory]:
|
||||||
|
"""
|
||||||
|
获取指定记忆
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory_id: 记忆ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentMemory文档或None
|
||||||
|
"""
|
||||||
|
return await AgentMemory.find_one(AgentMemory.memory_id == memory_id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_agent_memories(
|
||||||
|
cls,
|
||||||
|
agent_id: str,
|
||||||
|
memory_type: Optional[str] = None,
|
||||||
|
limit: int = 50
|
||||||
|
) -> List[AgentMemory]:
|
||||||
|
"""
|
||||||
|
获取Agent的记忆列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent ID
|
||||||
|
memory_type: 记忆类型(可选)
|
||||||
|
limit: 返回数量限制
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
记忆列表
|
||||||
|
"""
|
||||||
|
query = {"agent_id": agent_id}
|
||||||
|
if memory_type:
|
||||||
|
query["memory_type"] = memory_type
|
||||||
|
|
||||||
|
return await AgentMemory.find(query).sort(
|
||||||
|
"-importance", "-last_accessed"
|
||||||
|
).limit(limit).to_list()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def search_memories(
|
||||||
|
cls,
|
||||||
|
agent_id: str,
|
||||||
|
query: str,
|
||||||
|
limit: int = 10,
|
||||||
|
memory_type: Optional[str] = None,
|
||||||
|
min_relevance: float = 0.3
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
搜索相关记忆
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent ID
|
||||||
|
query: 查询文本
|
||||||
|
limit: 返回数量
|
||||||
|
memory_type: 记忆类型(可选)
|
||||||
|
min_relevance: 最小相关性阈值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
带相关性分数的记忆列表
|
||||||
|
"""
|
||||||
|
# 生成查询向量
|
||||||
|
query_embedding = await cls._generate_embedding(query)
|
||||||
|
if not query_embedding:
|
||||||
|
# 无法生成向量时,使用文本匹配
|
||||||
|
return await cls._text_search(agent_id, query, limit, memory_type)
|
||||||
|
|
||||||
|
# 获取Agent的所有记忆
|
||||||
|
filter_query = {"agent_id": agent_id}
|
||||||
|
if memory_type:
|
||||||
|
filter_query["memory_type"] = memory_type
|
||||||
|
|
||||||
|
memories = await AgentMemory.find(filter_query).to_list()
|
||||||
|
|
||||||
|
# 计算相似度
|
||||||
|
results = []
|
||||||
|
for memory in memories:
|
||||||
|
if memory.is_expired():
|
||||||
|
continue
|
||||||
|
|
||||||
|
if memory.embedding:
|
||||||
|
similarity = cls._cosine_similarity(query_embedding, memory.embedding)
|
||||||
|
relevance = memory.calculate_relevance_score(similarity)
|
||||||
|
|
||||||
|
if relevance >= min_relevance:
|
||||||
|
results.append({
|
||||||
|
"memory": memory,
|
||||||
|
"similarity": similarity,
|
||||||
|
"relevance": relevance
|
||||||
|
})
|
||||||
|
|
||||||
|
# 按相关性排序
|
||||||
|
results.sort(key=lambda x: x["relevance"], reverse=True)
|
||||||
|
|
||||||
|
# 更新访问记录
|
||||||
|
for item in results[:limit]:
|
||||||
|
memory = item["memory"]
|
||||||
|
memory.access()
|
||||||
|
await memory.save()
|
||||||
|
|
||||||
|
return results[:limit]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def update_memory(
|
||||||
|
cls,
|
||||||
|
memory_id: str,
|
||||||
|
**kwargs
|
||||||
|
) -> Optional[AgentMemory]:
|
||||||
|
"""
|
||||||
|
更新记忆
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory_id: 记忆ID
|
||||||
|
**kwargs: 要更新的字段
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
更新后的AgentMemory或None
|
||||||
|
"""
|
||||||
|
memory = await cls.get_memory(memory_id)
|
||||||
|
if not memory:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 如果更新了内容,重新生成嵌入
|
||||||
|
if "content" in kwargs:
|
||||||
|
kwargs["embedding"] = await cls._generate_embedding(kwargs["content"])
|
||||||
|
kwargs["summary"] = kwargs["content"][:100] + "..." if len(kwargs["content"]) > 100 else kwargs["content"]
|
||||||
|
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if hasattr(memory, key):
|
||||||
|
setattr(memory, key, value)
|
||||||
|
|
||||||
|
await memory.save()
|
||||||
|
return memory
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def delete_memory(cls, memory_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
删除记忆
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory_id: 记忆ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否删除成功
|
||||||
|
"""
|
||||||
|
memory = await cls.get_memory(memory_id)
|
||||||
|
if not memory:
|
||||||
|
return False
|
||||||
|
|
||||||
|
await memory.delete()
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def delete_agent_memories(
|
||||||
|
cls,
|
||||||
|
agent_id: str,
|
||||||
|
memory_type: Optional[str] = None
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
删除Agent的记忆
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent ID
|
||||||
|
memory_type: 记忆类型(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
删除的数量
|
||||||
|
"""
|
||||||
|
query = {"agent_id": agent_id}
|
||||||
|
if memory_type:
|
||||||
|
query["memory_type"] = memory_type
|
||||||
|
|
||||||
|
result = await AgentMemory.find(query).delete()
|
||||||
|
return result.deleted_count if result else 0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def cleanup_expired_memories(cls) -> int:
|
||||||
|
"""
|
||||||
|
清理过期的记忆
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
清理的数量
|
||||||
|
"""
|
||||||
|
now = datetime.utcnow()
|
||||||
|
result = await AgentMemory.find(
|
||||||
|
{"expires_at": {"$lt": now}}
|
||||||
|
).delete()
|
||||||
|
|
||||||
|
count = result.deleted_count if result else 0
|
||||||
|
if count > 0:
|
||||||
|
logger.info(f"清理了 {count} 条过期记忆")
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def consolidate_memories(
|
||||||
|
cls,
|
||||||
|
agent_id: str,
|
||||||
|
min_importance: float = 0.7,
|
||||||
|
max_age_days: int = 30
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
整合记忆(将重要的短期记忆转为长期记忆)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent ID
|
||||||
|
min_importance: 最小重要性阈值
|
||||||
|
max_age_days: 最大年龄(天)
|
||||||
|
"""
|
||||||
|
cutoff_date = datetime.utcnow() - timedelta(days=max_age_days)
|
||||||
|
|
||||||
|
# 查找符合条件的短期记忆
|
||||||
|
memories = await AgentMemory.find({
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"memory_type": MemoryType.SHORT_TERM.value,
|
||||||
|
"importance": {"$gte": min_importance},
|
||||||
|
"created_at": {"$lt": cutoff_date}
|
||||||
|
}).to_list()
|
||||||
|
|
||||||
|
for memory in memories:
|
||||||
|
memory.memory_type = MemoryType.LONG_TERM.value
|
||||||
|
memory.expires_at = None # 长期记忆不过期
|
||||||
|
await memory.save()
|
||||||
|
|
||||||
|
if memories:
|
||||||
|
logger.info(f"整合了 {len(memories)} 条记忆为长期记忆: Agent {agent_id}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _generate_embedding(cls, text: str) -> List[float]:
|
||||||
|
"""
|
||||||
|
生成文本的向量嵌入
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 文本内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
向量嵌入列表
|
||||||
|
"""
|
||||||
|
model = cls._get_embedding_model()
|
||||||
|
if model is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
embedding = model.encode(text, convert_to_numpy=True)
|
||||||
|
return embedding.tolist()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"生成嵌入失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _cosine_similarity(cls, vec1: List[float], vec2: List[float]) -> float:
|
||||||
|
"""
|
||||||
|
计算余弦相似度
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vec1: 向量1
|
||||||
|
vec2: 向量2
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
相似度 (0-1)
|
||||||
|
"""
|
||||||
|
if not vec1 or not vec2:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
try:
|
||||||
|
a = np.array(vec1)
|
||||||
|
b = np.array(vec2)
|
||||||
|
similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
|
||||||
|
return float(max(0, similarity))
|
||||||
|
except Exception:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _text_search(
|
||||||
|
cls,
|
||||||
|
agent_id: str,
|
||||||
|
query: str,
|
||||||
|
limit: int,
|
||||||
|
memory_type: Optional[str]
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
文本搜索(后备方案)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent ID
|
||||||
|
query: 查询文本
|
||||||
|
limit: 返回数量
|
||||||
|
memory_type: 记忆类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
记忆列表
|
||||||
|
"""
|
||||||
|
filter_query = {"agent_id": agent_id}
|
||||||
|
if memory_type:
|
||||||
|
filter_query["memory_type"] = memory_type
|
||||||
|
|
||||||
|
# 简单的文本匹配
|
||||||
|
memories = await AgentMemory.find(filter_query).to_list()
|
||||||
|
|
||||||
|
results = []
|
||||||
|
query_lower = query.lower()
|
||||||
|
for memory in memories:
|
||||||
|
if memory.is_expired():
|
||||||
|
continue
|
||||||
|
|
||||||
|
content_lower = memory.content.lower()
|
||||||
|
if query_lower in content_lower:
|
||||||
|
# 计算简单的匹配分数
|
||||||
|
score = len(query_lower) / len(content_lower)
|
||||||
|
results.append({
|
||||||
|
"memory": memory,
|
||||||
|
"similarity": score,
|
||||||
|
"relevance": score * memory.importance
|
||||||
|
})
|
||||||
|
|
||||||
|
results.sort(key=lambda x: x["relevance"], reverse=True)
|
||||||
|
return results[:limit]
|
||||||
335
backend/services/message_router.py
Normal file
335
backend/services/message_router.py
Normal file
@@ -0,0 +1,335 @@
|
|||||||
|
"""
|
||||||
|
消息路由服务
|
||||||
|
管理消息的发送和广播
|
||||||
|
"""
|
||||||
|
import uuid
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Any, Optional, Callable, Set
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from loguru import logger
|
||||||
|
from fastapi import WebSocket
|
||||||
|
|
||||||
|
from models.message import Message, MessageType
|
||||||
|
from models.chatroom import ChatRoom
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WebSocketConnection:
|
||||||
|
"""WebSocket连接信息"""
|
||||||
|
websocket: WebSocket
|
||||||
|
room_id: str
|
||||||
|
connected_at: datetime = field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageRouter:
|
||||||
|
"""
|
||||||
|
消息路由器
|
||||||
|
管理WebSocket连接和消息广播
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 房间连接映射: room_id -> Set[WebSocket]
|
||||||
|
_room_connections: Dict[str, Set[WebSocket]] = {}
|
||||||
|
|
||||||
|
# 消息回调: 用于外部订阅消息
|
||||||
|
_message_callbacks: List[Callable] = []
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def connect(cls, room_id: str, websocket: WebSocket) -> None:
|
||||||
|
"""
|
||||||
|
建立WebSocket连接
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
websocket: WebSocket实例
|
||||||
|
"""
|
||||||
|
await websocket.accept()
|
||||||
|
|
||||||
|
if room_id not in cls._room_connections:
|
||||||
|
cls._room_connections[room_id] = set()
|
||||||
|
|
||||||
|
cls._room_connections[room_id].add(websocket)
|
||||||
|
|
||||||
|
logger.info(f"WebSocket连接建立: {room_id}, 当前连接数: {len(cls._room_connections[room_id])}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def disconnect(cls, room_id: str, websocket: WebSocket) -> None:
|
||||||
|
"""
|
||||||
|
断开WebSocket连接
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
websocket: WebSocket实例
|
||||||
|
"""
|
||||||
|
if room_id in cls._room_connections:
|
||||||
|
cls._room_connections[room_id].discard(websocket)
|
||||||
|
|
||||||
|
# 清理空房间
|
||||||
|
if not cls._room_connections[room_id]:
|
||||||
|
del cls._room_connections[room_id]
|
||||||
|
|
||||||
|
logger.info(f"WebSocket连接断开: {room_id}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def broadcast_to_room(
|
||||||
|
cls,
|
||||||
|
room_id: str,
|
||||||
|
message: Dict[str, Any]
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
向聊天室广播消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
message: 消息内容
|
||||||
|
"""
|
||||||
|
if room_id not in cls._room_connections:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 获取所有连接
|
||||||
|
connections = cls._room_connections[room_id].copy()
|
||||||
|
|
||||||
|
# 并发发送
|
||||||
|
tasks = []
|
||||||
|
for websocket in connections:
|
||||||
|
tasks.append(cls._send_message(room_id, websocket, message))
|
||||||
|
|
||||||
|
if tasks:
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _send_message(
|
||||||
|
cls,
|
||||||
|
room_id: str,
|
||||||
|
websocket: WebSocket,
|
||||||
|
message: Dict[str, Any]
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
向单个WebSocket发送消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
websocket: WebSocket实例
|
||||||
|
message: 消息内容
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await websocket.send_json(message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"WebSocket发送失败: {e}")
|
||||||
|
# 移除断开的连接
|
||||||
|
await cls.disconnect(room_id, websocket)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def save_and_broadcast_message(
|
||||||
|
cls,
|
||||||
|
room_id: str,
|
||||||
|
discussion_id: str,
|
||||||
|
agent_id: Optional[str],
|
||||||
|
content: str,
|
||||||
|
message_type: str = MessageType.TEXT.value,
|
||||||
|
round_num: int = 0,
|
||||||
|
attachments: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
tool_calls: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
tool_results: Optional[List[Dict[str, Any]]] = None
|
||||||
|
) -> Message:
|
||||||
|
"""
|
||||||
|
保存消息并广播
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
discussion_id: 讨论ID
|
||||||
|
agent_id: 发送Agent ID
|
||||||
|
content: 消息内容
|
||||||
|
message_type: 消息类型
|
||||||
|
round_num: 轮次号
|
||||||
|
attachments: 附件
|
||||||
|
tool_calls: 工具调用
|
||||||
|
tool_results: 工具结果
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
保存的Message文档
|
||||||
|
"""
|
||||||
|
# 创建消息
|
||||||
|
message = Message(
|
||||||
|
message_id=f"msg-{uuid.uuid4().hex[:12]}",
|
||||||
|
room_id=room_id,
|
||||||
|
discussion_id=discussion_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
content=content,
|
||||||
|
message_type=message_type,
|
||||||
|
attachments=attachments or [],
|
||||||
|
round=round_num,
|
||||||
|
token_count=len(content) // 4, # 粗略估计
|
||||||
|
tool_calls=tool_calls or [],
|
||||||
|
tool_results=tool_results or [],
|
||||||
|
created_at=datetime.utcnow()
|
||||||
|
)
|
||||||
|
|
||||||
|
await message.insert()
|
||||||
|
|
||||||
|
# 构建广播消息
|
||||||
|
broadcast_data = {
|
||||||
|
"type": "message",
|
||||||
|
"data": {
|
||||||
|
"message_id": message.message_id,
|
||||||
|
"room_id": message.room_id,
|
||||||
|
"discussion_id": message.discussion_id,
|
||||||
|
"agent_id": message.agent_id,
|
||||||
|
"content": message.content,
|
||||||
|
"message_type": message.message_type,
|
||||||
|
"round": message.round,
|
||||||
|
"created_at": message.created_at.isoformat()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# 广播消息
|
||||||
|
await cls.broadcast_to_room(room_id, broadcast_data)
|
||||||
|
|
||||||
|
# 触发回调
|
||||||
|
for callback in cls._message_callbacks:
|
||||||
|
try:
|
||||||
|
await callback(message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"消息回调执行失败: {e}")
|
||||||
|
|
||||||
|
return message
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def broadcast_status(
|
||||||
|
cls,
|
||||||
|
room_id: str,
|
||||||
|
status: str,
|
||||||
|
data: Optional[Dict[str, Any]] = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
广播状态更新
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
status: 状态类型
|
||||||
|
data: 附加数据
|
||||||
|
"""
|
||||||
|
message = {
|
||||||
|
"type": "status",
|
||||||
|
"status": status,
|
||||||
|
"data": data or {},
|
||||||
|
"timestamp": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
await cls.broadcast_to_room(room_id, message)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def broadcast_typing(
|
||||||
|
cls,
|
||||||
|
room_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
is_typing: bool = True
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
广播Agent输入状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
agent_id: Agent ID
|
||||||
|
is_typing: 是否正在输入
|
||||||
|
"""
|
||||||
|
message = {
|
||||||
|
"type": "typing",
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"is_typing": is_typing,
|
||||||
|
"timestamp": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
await cls.broadcast_to_room(room_id, message)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def broadcast_round_info(
|
||||||
|
cls,
|
||||||
|
room_id: str,
|
||||||
|
round_num: int,
|
||||||
|
total_rounds: int
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
广播轮次信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
round_num: 当前轮次
|
||||||
|
total_rounds: 最大轮次
|
||||||
|
"""
|
||||||
|
message = {
|
||||||
|
"type": "round",
|
||||||
|
"round": round_num,
|
||||||
|
"total_rounds": total_rounds,
|
||||||
|
"timestamp": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
await cls.broadcast_to_room(room_id, message)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def broadcast_error(
|
||||||
|
cls,
|
||||||
|
room_id: str,
|
||||||
|
error: str,
|
||||||
|
agent_id: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
广播错误信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
error: 错误信息
|
||||||
|
agent_id: 相关Agent ID
|
||||||
|
"""
|
||||||
|
message = {
|
||||||
|
"type": "error",
|
||||||
|
"error": error,
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"timestamp": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
await cls.broadcast_to_room(room_id, message)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_callback(cls, callback: Callable) -> None:
|
||||||
|
"""
|
||||||
|
注册消息回调
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callback: 回调函数,接收Message参数
|
||||||
|
"""
|
||||||
|
cls._message_callbacks.append(callback)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def unregister_callback(cls, callback: Callable) -> None:
|
||||||
|
"""
|
||||||
|
注销消息回调
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callback: 回调函数
|
||||||
|
"""
|
||||||
|
if callback in cls._message_callbacks:
|
||||||
|
cls._message_callbacks.remove(callback)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_connection_count(cls, room_id: str) -> int:
|
||||||
|
"""
|
||||||
|
获取房间连接数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: 聊天室ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
连接数
|
||||||
|
"""
|
||||||
|
return len(cls._room_connections.get(room_id, set()))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_all_room_ids(cls) -> List[str]:
|
||||||
|
"""
|
||||||
|
获取所有活跃房间ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
房间ID列表
|
||||||
|
"""
|
||||||
|
return list(cls._room_connections.keys())
|
||||||
13
backend/utils/__init__.py
Normal file
13
backend/utils/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
"""
|
||||||
|
工具函数模块
|
||||||
|
"""
|
||||||
|
from .encryption import encrypt_api_key, decrypt_api_key
|
||||||
|
from .proxy_handler import get_http_client
|
||||||
|
from .rate_limiter import RateLimiter
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"encrypt_api_key",
|
||||||
|
"decrypt_api_key",
|
||||||
|
"get_http_client",
|
||||||
|
"RateLimiter",
|
||||||
|
]
|
||||||
97
backend/utils/encryption.py
Normal file
97
backend/utils/encryption.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
"""
|
||||||
|
加密工具模块
|
||||||
|
用于API密钥的加密和解密
|
||||||
|
"""
|
||||||
|
import base64
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
from cryptography.hazmat.primitives import hashes
|
||||||
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from config import settings
|
||||||
|
|
||||||
|
|
||||||
|
def _get_fernet() -> Fernet:
|
||||||
|
"""
|
||||||
|
获取Fernet加密器实例
|
||||||
|
使用配置的加密密钥派生加密密钥
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Fernet加密器
|
||||||
|
"""
|
||||||
|
# 使用PBKDF2从密钥派生32字节密钥
|
||||||
|
salt = b"ai_chatroom_salt" # 固定salt,实际生产环境应使用随机salt
|
||||||
|
kdf = PBKDF2HMAC(
|
||||||
|
algorithm=hashes.SHA256(),
|
||||||
|
length=32,
|
||||||
|
salt=salt,
|
||||||
|
iterations=100000,
|
||||||
|
)
|
||||||
|
key = base64.urlsafe_b64encode(
|
||||||
|
kdf.derive(settings.ENCRYPTION_KEY.encode())
|
||||||
|
)
|
||||||
|
return Fernet(key)
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt_api_key(api_key: str) -> str:
|
||||||
|
"""
|
||||||
|
加密API密钥
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: 原始API密钥
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
加密后的密钥字符串
|
||||||
|
"""
|
||||||
|
if not api_key:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
fernet = _get_fernet()
|
||||||
|
encrypted = fernet.encrypt(api_key.encode())
|
||||||
|
return encrypted.decode()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"API密钥加密失败: {e}")
|
||||||
|
raise ValueError("加密失败")
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_api_key(encrypted_key: str) -> str:
|
||||||
|
"""
|
||||||
|
解密API密钥
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encrypted_key: 加密的密钥字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
解密后的原始API密钥
|
||||||
|
"""
|
||||||
|
if not encrypted_key:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
fernet = _get_fernet()
|
||||||
|
decrypted = fernet.decrypt(encrypted_key.encode())
|
||||||
|
return decrypted.decode()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"API密钥解密失败: {e}")
|
||||||
|
raise ValueError("解密失败,密钥可能已损坏或被篡改")
|
||||||
|
|
||||||
|
|
||||||
|
def mask_api_key(api_key: str, visible_chars: int = 4) -> str:
|
||||||
|
"""
|
||||||
|
掩码API密钥,用于安全显示
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: 原始API密钥
|
||||||
|
visible_chars: 末尾可见字符数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
掩码后的密钥 (如: ****abc1)
|
||||||
|
"""
|
||||||
|
if not api_key:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if len(api_key) <= visible_chars:
|
||||||
|
return "*" * len(api_key)
|
||||||
|
|
||||||
|
return "*" * (len(api_key) - visible_chars) + api_key[-visible_chars:]
|
||||||
135
backend/utils/proxy_handler.py
Normal file
135
backend/utils/proxy_handler.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
"""
|
||||||
|
代理处理模块
|
||||||
|
处理HTTP代理配置
|
||||||
|
"""
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
import httpx
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from config import settings
|
||||||
|
|
||||||
|
|
||||||
|
def get_proxy_dict(
|
||||||
|
use_proxy: bool,
|
||||||
|
proxy_config: Optional[Dict[str, Any]] = None
|
||||||
|
) -> Optional[Dict[str, str]]:
|
||||||
|
"""
|
||||||
|
获取代理配置字典
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_proxy: 是否使用代理
|
||||||
|
proxy_config: 代理配置
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
代理配置字典或None
|
||||||
|
"""
|
||||||
|
if not use_proxy:
|
||||||
|
return None
|
||||||
|
|
||||||
|
proxies = {}
|
||||||
|
|
||||||
|
if proxy_config:
|
||||||
|
http_proxy = proxy_config.get("http_proxy")
|
||||||
|
https_proxy = proxy_config.get("https_proxy")
|
||||||
|
else:
|
||||||
|
# 使用全局默认代理
|
||||||
|
http_proxy = settings.DEFAULT_HTTP_PROXY
|
||||||
|
https_proxy = settings.DEFAULT_HTTPS_PROXY
|
||||||
|
|
||||||
|
if http_proxy:
|
||||||
|
proxies["http://"] = http_proxy
|
||||||
|
if https_proxy:
|
||||||
|
proxies["https://"] = https_proxy
|
||||||
|
|
||||||
|
return proxies if proxies else None
|
||||||
|
|
||||||
|
|
||||||
|
def get_http_client(
|
||||||
|
use_proxy: bool = False,
|
||||||
|
proxy_config: Optional[Dict[str, Any]] = None,
|
||||||
|
timeout: int = 60,
|
||||||
|
**kwargs
|
||||||
|
) -> httpx.AsyncClient:
|
||||||
|
"""
|
||||||
|
获取配置好的HTTP异步客户端
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_proxy: 是否使用代理
|
||||||
|
proxy_config: 代理配置
|
||||||
|
timeout: 超时时间(秒)
|
||||||
|
**kwargs: 其他httpx参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
配置好的httpx.AsyncClient实例
|
||||||
|
"""
|
||||||
|
proxies = get_proxy_dict(use_proxy, proxy_config)
|
||||||
|
|
||||||
|
client_kwargs = {
|
||||||
|
"timeout": httpx.Timeout(timeout),
|
||||||
|
"follow_redirects": True,
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
if proxies:
|
||||||
|
client_kwargs["proxies"] = proxies
|
||||||
|
logger.debug(f"HTTP客户端使用代理: {proxies}")
|
||||||
|
|
||||||
|
return httpx.AsyncClient(**client_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_proxy_connection(
|
||||||
|
proxy_config: Dict[str, Any],
|
||||||
|
test_url: str = "https://www.google.com"
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
测试代理连接是否可用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
proxy_config: 代理配置
|
||||||
|
test_url: 测试URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
测试结果字典,包含 success, message, latency_ms
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async with get_http_client(
|
||||||
|
use_proxy=True,
|
||||||
|
proxy_config=proxy_config,
|
||||||
|
timeout=10
|
||||||
|
) as client:
|
||||||
|
import time
|
||||||
|
start = time.time()
|
||||||
|
response = await client.get(test_url)
|
||||||
|
latency = (time.time() - start) * 1000
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": "代理连接正常",
|
||||||
|
"latency_ms": round(latency, 2)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": f"代理返回状态码: {response.status_code}",
|
||||||
|
"latency_ms": round(latency, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
except httpx.ProxyError as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": f"代理连接失败: {str(e)}",
|
||||||
|
"latency_ms": None
|
||||||
|
}
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": "代理连接超时",
|
||||||
|
"latency_ms": None
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": f"连接错误: {str(e)}",
|
||||||
|
"latency_ms": None
|
||||||
|
}
|
||||||
233
backend/utils/rate_limiter.py
Normal file
233
backend/utils/rate_limiter.py
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
"""
|
||||||
|
速率限制器模块
|
||||||
|
使用令牌桶算法控制请求频率
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from typing import Dict, Optional
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TokenBucket:
|
||||||
|
"""令牌桶"""
|
||||||
|
capacity: int # 桶容量
|
||||||
|
tokens: float = field(init=False) # 当前令牌数
|
||||||
|
refill_rate: float # 每秒填充速率
|
||||||
|
last_refill: float = field(default_factory=time.time)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.tokens = float(self.capacity)
|
||||||
|
|
||||||
|
def _refill(self) -> None:
|
||||||
|
"""填充令牌"""
|
||||||
|
now = time.time()
|
||||||
|
elapsed = now - self.last_refill
|
||||||
|
self.tokens = min(
|
||||||
|
self.capacity,
|
||||||
|
self.tokens + elapsed * self.refill_rate
|
||||||
|
)
|
||||||
|
self.last_refill = now
|
||||||
|
|
||||||
|
def consume(self, tokens: int = 1) -> bool:
|
||||||
|
"""
|
||||||
|
尝试消费令牌
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens: 要消费的令牌数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否消费成功
|
||||||
|
"""
|
||||||
|
self._refill()
|
||||||
|
if self.tokens >= tokens:
|
||||||
|
self.tokens -= tokens
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def wait_time(self, tokens: int = 1) -> float:
|
||||||
|
"""
|
||||||
|
计算需要等待的时间
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens: 需要的令牌数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
需要等待的秒数
|
||||||
|
"""
|
||||||
|
self._refill()
|
||||||
|
if self.tokens >= tokens:
|
||||||
|
return 0.0
|
||||||
|
needed = tokens - self.tokens
|
||||||
|
return needed / self.refill_rate
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimiter:
|
||||||
|
"""
|
||||||
|
速率限制器
|
||||||
|
管理多个提供商的速率限制
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._buckets: Dict[str, TokenBucket] = {}
|
||||||
|
self._locks: Dict[str, asyncio.Lock] = {}
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self,
|
||||||
|
provider_id: str,
|
||||||
|
requests_per_minute: int = 60,
|
||||||
|
tokens_per_minute: int = 100000
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
注册提供商的速率限制
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_id: 提供商ID
|
||||||
|
requests_per_minute: 每分钟请求数
|
||||||
|
tokens_per_minute: 每分钟token数
|
||||||
|
"""
|
||||||
|
# 请求限制桶
|
||||||
|
self._buckets[f"{provider_id}:requests"] = TokenBucket(
|
||||||
|
capacity=requests_per_minute,
|
||||||
|
refill_rate=requests_per_minute / 60.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Token限制桶
|
||||||
|
self._buckets[f"{provider_id}:tokens"] = TokenBucket(
|
||||||
|
capacity=tokens_per_minute,
|
||||||
|
refill_rate=tokens_per_minute / 60.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建锁
|
||||||
|
self._locks[provider_id] = asyncio.Lock()
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"注册速率限制: {provider_id} - "
|
||||||
|
f"{requests_per_minute}请求/分钟, "
|
||||||
|
f"{tokens_per_minute}tokens/分钟"
|
||||||
|
)
|
||||||
|
|
||||||
|
def unregister(self, provider_id: str) -> None:
|
||||||
|
"""
|
||||||
|
取消注册提供商的速率限制
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_id: 提供商ID
|
||||||
|
"""
|
||||||
|
self._buckets.pop(f"{provider_id}:requests", None)
|
||||||
|
self._buckets.pop(f"{provider_id}:tokens", None)
|
||||||
|
self._locks.pop(provider_id, None)
|
||||||
|
|
||||||
|
async def acquire(
|
||||||
|
self,
|
||||||
|
provider_id: str,
|
||||||
|
estimated_tokens: int = 1
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
获取请求许可(非阻塞)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_id: 提供商ID
|
||||||
|
estimated_tokens: 预估token数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否获取成功
|
||||||
|
"""
|
||||||
|
request_bucket = self._buckets.get(f"{provider_id}:requests")
|
||||||
|
token_bucket = self._buckets.get(f"{provider_id}:tokens")
|
||||||
|
|
||||||
|
if not request_bucket or not token_bucket:
|
||||||
|
# 未注册,默认允许
|
||||||
|
return True
|
||||||
|
|
||||||
|
lock = self._locks.get(provider_id)
|
||||||
|
if lock:
|
||||||
|
async with lock:
|
||||||
|
if request_bucket.consume(1) and token_bucket.consume(estimated_tokens):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def acquire_wait(
|
||||||
|
self,
|
||||||
|
provider_id: str,
|
||||||
|
estimated_tokens: int = 1,
|
||||||
|
max_wait: float = 60.0
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
获取请求许可(阻塞等待)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_id: 提供商ID
|
||||||
|
estimated_tokens: 预估token数
|
||||||
|
max_wait: 最大等待时间(秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否获取成功
|
||||||
|
"""
|
||||||
|
request_bucket = self._buckets.get(f"{provider_id}:requests")
|
||||||
|
token_bucket = self._buckets.get(f"{provider_id}:tokens")
|
||||||
|
|
||||||
|
if not request_bucket or not token_bucket:
|
||||||
|
return True
|
||||||
|
|
||||||
|
lock = self._locks.get(provider_id)
|
||||||
|
if not lock:
|
||||||
|
return True
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
async with lock:
|
||||||
|
# 计算需要等待的时间
|
||||||
|
request_wait = request_bucket.wait_time(1)
|
||||||
|
token_wait = token_bucket.wait_time(estimated_tokens)
|
||||||
|
wait_time = max(request_wait, token_wait)
|
||||||
|
|
||||||
|
if wait_time == 0:
|
||||||
|
request_bucket.consume(1)
|
||||||
|
token_bucket.consume(estimated_tokens)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 检查是否超时
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
if elapsed + wait_time > max_wait:
|
||||||
|
logger.warning(
|
||||||
|
f"速率限制等待超时: {provider_id}, "
|
||||||
|
f"需要等待{wait_time:.2f}秒"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 在锁外等待
|
||||||
|
await asyncio.sleep(min(wait_time, 1.0))
|
||||||
|
|
||||||
|
def get_status(self, provider_id: str) -> Optional[Dict[str, any]]:
|
||||||
|
"""
|
||||||
|
获取提供商的速率限制状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_id: 提供商ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
状态字典
|
||||||
|
"""
|
||||||
|
request_bucket = self._buckets.get(f"{provider_id}:requests")
|
||||||
|
token_bucket = self._buckets.get(f"{provider_id}:tokens")
|
||||||
|
|
||||||
|
if not request_bucket or not token_bucket:
|
||||||
|
return None
|
||||||
|
|
||||||
|
request_bucket._refill()
|
||||||
|
token_bucket._refill()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"requests_remaining": int(request_bucket.tokens),
|
||||||
|
"requests_capacity": request_bucket.capacity,
|
||||||
|
"tokens_remaining": int(token_bucket.tokens),
|
||||||
|
"tokens_capacity": token_bucket.capacity
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 全局速率限制器实例
|
||||||
|
rate_limiter = RateLimiter()
|
||||||
70
docker-compose.yml
Normal file
70
docker-compose.yml
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
version: '3.8'
|
||||||
|
|
||||||
|
services:
|
||||||
|
# MongoDB数据库
|
||||||
|
mongodb:
|
||||||
|
image: mongo:7
|
||||||
|
container_name: ai-chatroom-mongodb
|
||||||
|
restart: unless-stopped
|
||||||
|
ports:
|
||||||
|
- "27017:27017"
|
||||||
|
volumes:
|
||||||
|
- mongodb_data:/data/db
|
||||||
|
environment:
|
||||||
|
MONGO_INITDB_ROOT_USERNAME: admin
|
||||||
|
MONGO_INITDB_ROOT_PASSWORD: ${MONGO_PASSWORD:-chatroom123}
|
||||||
|
networks:
|
||||||
|
- chatroom-network
|
||||||
|
healthcheck:
|
||||||
|
test: echo 'db.runCommand("ping").ok' | mongosh localhost:27017/test --quiet
|
||||||
|
interval: 10s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
|
||||||
|
# 后端服务
|
||||||
|
backend:
|
||||||
|
build:
|
||||||
|
context: ./backend
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
container_name: ai-chatroom-backend
|
||||||
|
restart: unless-stopped
|
||||||
|
ports:
|
||||||
|
- "8000:8000"
|
||||||
|
depends_on:
|
||||||
|
mongodb:
|
||||||
|
condition: service_healthy
|
||||||
|
environment:
|
||||||
|
- MONGODB_URL=mongodb://admin:${MONGO_PASSWORD:-chatroom123}@mongodb:27017
|
||||||
|
- MONGODB_DB=ai_chatroom
|
||||||
|
- HOST=0.0.0.0
|
||||||
|
- PORT=8000
|
||||||
|
- DEBUG=${DEBUG:-false}
|
||||||
|
- SECRET_KEY=${SECRET_KEY:-change-this-in-production}
|
||||||
|
- ENCRYPTION_KEY=${ENCRYPTION_KEY:-your-32-byte-encryption-key-here}
|
||||||
|
volumes:
|
||||||
|
- ./backend:/app
|
||||||
|
networks:
|
||||||
|
- chatroom-network
|
||||||
|
|
||||||
|
# 前端服务
|
||||||
|
frontend:
|
||||||
|
build:
|
||||||
|
context: ./frontend
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
container_name: ai-chatroom-frontend
|
||||||
|
restart: unless-stopped
|
||||||
|
ports:
|
||||||
|
- "3000:80"
|
||||||
|
depends_on:
|
||||||
|
- backend
|
||||||
|
environment:
|
||||||
|
- VITE_API_URL=http://backend:8000
|
||||||
|
networks:
|
||||||
|
- chatroom-network
|
||||||
|
|
||||||
|
networks:
|
||||||
|
chatroom-network:
|
||||||
|
driver: bridge
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
mongodb_data:
|
||||||
31
frontend/Dockerfile
Normal file
31
frontend/Dockerfile
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
# AI聊天室前端 Dockerfile
|
||||||
|
# 构建阶段
|
||||||
|
FROM node:20-alpine as builder
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# 复制依赖文件
|
||||||
|
COPY package.json .
|
||||||
|
|
||||||
|
# 安装依赖
|
||||||
|
RUN npm install
|
||||||
|
|
||||||
|
# 复制源代码
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
# 构建
|
||||||
|
RUN npm run build
|
||||||
|
|
||||||
|
# 生产阶段
|
||||||
|
FROM nginx:alpine
|
||||||
|
|
||||||
|
# 复制构建产物
|
||||||
|
COPY --from=builder /app/dist /usr/share/nginx/html
|
||||||
|
|
||||||
|
# 复制nginx配置
|
||||||
|
COPY nginx.conf /etc/nginx/conf.d/default.conf
|
||||||
|
|
||||||
|
# 暴露端口
|
||||||
|
EXPOSE 80
|
||||||
|
|
||||||
|
CMD ["nginx", "-g", "daemon off;"]
|
||||||
13
frontend/index.html
Normal file
13
frontend/index.html
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="zh-CN">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8" />
|
||||||
|
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
|
<title>AI聊天室 - 多Agent协作讨论平台</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div id="root"></div>
|
||||||
|
<script type="module" src="/src/main.tsx"></script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
46
frontend/nginx.conf
Normal file
46
frontend/nginx.conf
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
server {
|
||||||
|
listen 80;
|
||||||
|
server_name localhost;
|
||||||
|
|
||||||
|
root /usr/share/nginx/html;
|
||||||
|
index index.html;
|
||||||
|
|
||||||
|
# 前端路由支持
|
||||||
|
location / {
|
||||||
|
try_files $uri $uri/ /index.html;
|
||||||
|
}
|
||||||
|
|
||||||
|
# API代理
|
||||||
|
location /api {
|
||||||
|
proxy_pass http://backend:8000;
|
||||||
|
proxy_http_version 1.1;
|
||||||
|
proxy_set_header Upgrade $http_upgrade;
|
||||||
|
proxy_set_header Connection "upgrade";
|
||||||
|
proxy_set_header Host $host;
|
||||||
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
proxy_read_timeout 86400;
|
||||||
|
}
|
||||||
|
|
||||||
|
# WebSocket代理
|
||||||
|
location /ws {
|
||||||
|
proxy_pass http://backend:8000;
|
||||||
|
proxy_http_version 1.1;
|
||||||
|
proxy_set_header Upgrade $http_upgrade;
|
||||||
|
proxy_set_header Connection "upgrade";
|
||||||
|
proxy_set_header Host $host;
|
||||||
|
proxy_read_timeout 86400;
|
||||||
|
}
|
||||||
|
|
||||||
|
# 静态资源缓存
|
||||||
|
location ~* \.(js|css|png|jpg|jpeg|gif|ico|svg)$ {
|
||||||
|
expires 30d;
|
||||||
|
add_header Cache-Control "public, immutable";
|
||||||
|
}
|
||||||
|
|
||||||
|
# Gzip压缩
|
||||||
|
gzip on;
|
||||||
|
gzip_types text/plain text/css application/json application/javascript text/xml application/xml application/xml+rss text/javascript;
|
||||||
|
gzip_min_length 1000;
|
||||||
|
}
|
||||||
3192
frontend/package-lock.json
generated
Normal file
3192
frontend/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
28
frontend/package.json
Normal file
28
frontend/package.json
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
{
|
||||||
|
"name": "ai-chatroom-frontend",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"private": true,
|
||||||
|
"dependencies": {
|
||||||
|
"@ant-design/icons": "^5.2.6",
|
||||||
|
"antd": "^5.13.0",
|
||||||
|
"axios": "^1.6.5",
|
||||||
|
"react": "^18.2.0",
|
||||||
|
"react-dom": "^18.2.0",
|
||||||
|
"react-router-dom": "^6.21.1",
|
||||||
|
"zustand": "^4.4.7",
|
||||||
|
"socket.io-client": "^4.7.4",
|
||||||
|
"dayjs": "^1.11.10"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@types/react": "^18.2.47",
|
||||||
|
"@types/react-dom": "^18.2.18",
|
||||||
|
"@vitejs/plugin-react": "^4.2.1",
|
||||||
|
"typescript": "^5.3.3",
|
||||||
|
"vite": "^5.0.11"
|
||||||
|
},
|
||||||
|
"scripts": {
|
||||||
|
"dev": "vite",
|
||||||
|
"build": "tsc && vite build",
|
||||||
|
"preview": "vite preview"
|
||||||
|
}
|
||||||
|
}
|
||||||
39
frontend/src/App.tsx
Normal file
39
frontend/src/App.tsx
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
/**
|
||||||
|
* 应用主组件
|
||||||
|
* 定义路由和布局
|
||||||
|
*/
|
||||||
|
import React from 'react'
|
||||||
|
import { BrowserRouter, Routes, Route, Navigate } from 'react-router-dom'
|
||||||
|
import { Layout } from 'antd'
|
||||||
|
import Sidebar from './components/Sidebar'
|
||||||
|
import Dashboard from './pages/Dashboard'
|
||||||
|
import ProviderConfig from './pages/ProviderConfig'
|
||||||
|
import AgentManagement from './pages/AgentManagement'
|
||||||
|
import ChatRoom from './pages/ChatRoom'
|
||||||
|
import DiscussionHistory from './pages/DiscussionHistory'
|
||||||
|
|
||||||
|
const { Content } = Layout
|
||||||
|
|
||||||
|
const App: React.FC = () => {
|
||||||
|
return (
|
||||||
|
<BrowserRouter>
|
||||||
|
<Layout style={{ minHeight: '100vh' }}>
|
||||||
|
<Sidebar />
|
||||||
|
<Layout>
|
||||||
|
<Content style={{ margin: '24px', overflow: 'auto' }}>
|
||||||
|
<Routes>
|
||||||
|
<Route path="/" element={<Navigate to="/dashboard\" replace />} />
|
||||||
|
<Route path="/dashboard" element={<Dashboard />} />
|
||||||
|
<Route path="/providers" element={<ProviderConfig />} />
|
||||||
|
<Route path="/agents" element={<AgentManagement />} />
|
||||||
|
<Route path="/chatroom/:roomId?" element={<ChatRoom />} />
|
||||||
|
<Route path="/history" element={<DiscussionHistory />} />
|
||||||
|
</Routes>
|
||||||
|
</Content>
|
||||||
|
</Layout>
|
||||||
|
</Layout>
|
||||||
|
</BrowserRouter>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default App
|
||||||
122
frontend/src/components/AgentCard.tsx
Normal file
122
frontend/src/components/AgentCard.tsx
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
/**
|
||||||
|
* Agent卡片组件
|
||||||
|
*/
|
||||||
|
import React from 'react'
|
||||||
|
import { Card, Avatar, Tag, Switch, Typography, Space, Button, Tooltip } from 'antd'
|
||||||
|
import {
|
||||||
|
RobotOutlined,
|
||||||
|
EditOutlined,
|
||||||
|
DeleteOutlined,
|
||||||
|
CopyOutlined,
|
||||||
|
PlayCircleOutlined
|
||||||
|
} from '@ant-design/icons'
|
||||||
|
import type { Agent } from '../types'
|
||||||
|
|
||||||
|
const { Text, Paragraph } = Typography
|
||||||
|
|
||||||
|
interface AgentCardProps {
|
||||||
|
agent: Agent
|
||||||
|
onEdit?: (agent: Agent) => void
|
||||||
|
onDelete?: (agent: Agent) => void
|
||||||
|
onDuplicate?: (agent: Agent) => void
|
||||||
|
onTest?: (agent: Agent) => void
|
||||||
|
onToggleEnabled?: (agent: Agent, enabled: boolean) => void
|
||||||
|
}
|
||||||
|
|
||||||
|
const AgentCard: React.FC<AgentCardProps> = ({
|
||||||
|
agent,
|
||||||
|
onEdit,
|
||||||
|
onDelete,
|
||||||
|
onDuplicate,
|
||||||
|
onTest,
|
||||||
|
onToggleEnabled
|
||||||
|
}) => {
|
||||||
|
return (
|
||||||
|
<Card
|
||||||
|
hoverable
|
||||||
|
style={{
|
||||||
|
borderTop: `3px solid ${agent.color}`,
|
||||||
|
opacity: agent.enabled ? 1 : 0.6
|
||||||
|
}}
|
||||||
|
actions={[
|
||||||
|
<Tooltip title="测试对话" key="test">
|
||||||
|
<Button
|
||||||
|
type="text"
|
||||||
|
icon={<PlayCircleOutlined />}
|
||||||
|
onClick={() => onTest?.(agent)}
|
||||||
|
/>
|
||||||
|
</Tooltip>,
|
||||||
|
<Tooltip title="复制" key="copy">
|
||||||
|
<Button
|
||||||
|
type="text"
|
||||||
|
icon={<CopyOutlined />}
|
||||||
|
onClick={() => onDuplicate?.(agent)}
|
||||||
|
/>
|
||||||
|
</Tooltip>,
|
||||||
|
<Tooltip title="编辑" key="edit">
|
||||||
|
<Button
|
||||||
|
type="text"
|
||||||
|
icon={<EditOutlined />}
|
||||||
|
onClick={() => onEdit?.(agent)}
|
||||||
|
/>
|
||||||
|
</Tooltip>,
|
||||||
|
<Tooltip title="删除" key="delete">
|
||||||
|
<Button
|
||||||
|
type="text"
|
||||||
|
danger
|
||||||
|
icon={<DeleteOutlined />}
|
||||||
|
onClick={() => onDelete?.(agent)}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
]}
|
||||||
|
>
|
||||||
|
<Card.Meta
|
||||||
|
avatar={
|
||||||
|
<Avatar
|
||||||
|
size={48}
|
||||||
|
icon={<RobotOutlined />}
|
||||||
|
style={{ backgroundColor: agent.color }}
|
||||||
|
src={agent.avatar}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
title={
|
||||||
|
<Space>
|
||||||
|
<span>{agent.name}</span>
|
||||||
|
<Switch
|
||||||
|
size="small"
|
||||||
|
checked={agent.enabled}
|
||||||
|
onChange={(checked) => onToggleEnabled?.(agent, checked)}
|
||||||
|
/>
|
||||||
|
</Space>
|
||||||
|
}
|
||||||
|
description={
|
||||||
|
<div>
|
||||||
|
<Text type="secondary">{agent.role}</Text>
|
||||||
|
|
||||||
|
<div style={{ marginTop: 8 }}>
|
||||||
|
{agent.capabilities.memory_enabled && (
|
||||||
|
<Tag color="blue">记忆</Tag>
|
||||||
|
)}
|
||||||
|
{agent.capabilities.multimodal && (
|
||||||
|
<Tag color="purple">多模态</Tag>
|
||||||
|
)}
|
||||||
|
{agent.capabilities.mcp_tools.length > 0 && (
|
||||||
|
<Tag color="green">MCP工具</Tag>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Paragraph
|
||||||
|
type="secondary"
|
||||||
|
ellipsis={{ rows: 2 }}
|
||||||
|
style={{ marginTop: 8, marginBottom: 0, fontSize: 12 }}
|
||||||
|
>
|
||||||
|
{agent.system_prompt}
|
||||||
|
</Paragraph>
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</Card>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default AgentCard
|
||||||
97
frontend/src/components/MessageBubble.tsx
Normal file
97
frontend/src/components/MessageBubble.tsx
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
/**
|
||||||
|
* 消息气泡组件
|
||||||
|
*/
|
||||||
|
import React from 'react'
|
||||||
|
import { Avatar, Typography } from 'antd'
|
||||||
|
import { RobotOutlined, InfoCircleOutlined } from '@ant-design/icons'
|
||||||
|
import type { Message, Agent } from '../types'
|
||||||
|
import dayjs from 'dayjs'
|
||||||
|
|
||||||
|
const { Text, Paragraph } = Typography
|
||||||
|
|
||||||
|
interface MessageBubbleProps {
|
||||||
|
message: Message
|
||||||
|
agent?: Agent
|
||||||
|
}
|
||||||
|
|
||||||
|
const MessageBubble: React.FC<MessageBubbleProps> = ({ message, agent }) => {
|
||||||
|
const isSystem = message.message_type === 'system' || !message.agent_id
|
||||||
|
|
||||||
|
if (isSystem) {
|
||||||
|
return (
|
||||||
|
<div className="message-fade-in" style={{
|
||||||
|
display: 'flex',
|
||||||
|
justifyContent: 'center',
|
||||||
|
margin: '16px 0'
|
||||||
|
}}>
|
||||||
|
<div style={{
|
||||||
|
background: '#f0f0f0',
|
||||||
|
padding: '8px 16px',
|
||||||
|
borderRadius: 8,
|
||||||
|
maxWidth: '80%'
|
||||||
|
}}>
|
||||||
|
<InfoCircleOutlined style={{ marginRight: 8, color: '#666' }} />
|
||||||
|
<Text type="secondary" style={{ whiteSpace: 'pre-wrap' }}>
|
||||||
|
{message.content}
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const agentColor = agent?.color || '#1890ff'
|
||||||
|
const agentName = agent?.name || message.agent_id
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="message-fade-in" style={{
|
||||||
|
display: 'flex',
|
||||||
|
gap: 12,
|
||||||
|
margin: '12px 0',
|
||||||
|
padding: '8px 12px'
|
||||||
|
}}>
|
||||||
|
<Avatar
|
||||||
|
size={40}
|
||||||
|
icon={<RobotOutlined />}
|
||||||
|
style={{
|
||||||
|
backgroundColor: agentColor,
|
||||||
|
flexShrink: 0
|
||||||
|
}}
|
||||||
|
src={agent?.avatar}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<div style={{ flex: 1, minWidth: 0 }}>
|
||||||
|
<div style={{
|
||||||
|
display: 'flex',
|
||||||
|
alignItems: 'center',
|
||||||
|
gap: 8,
|
||||||
|
marginBottom: 4
|
||||||
|
}}>
|
||||||
|
<Text strong style={{ color: agentColor }}>
|
||||||
|
{agentName}
|
||||||
|
</Text>
|
||||||
|
<Text type="secondary" style={{ fontSize: 12 }}>
|
||||||
|
第{message.round}轮 · {dayjs(message.created_at).format('HH:mm:ss')}
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div style={{
|
||||||
|
background: '#fff',
|
||||||
|
padding: '12px 16px',
|
||||||
|
borderRadius: '0 12px 12px 12px',
|
||||||
|
boxShadow: '0 1px 2px rgba(0,0,0,0.1)',
|
||||||
|
borderLeft: `3px solid ${agentColor}`
|
||||||
|
}}>
|
||||||
|
<Paragraph style={{
|
||||||
|
margin: 0,
|
||||||
|
whiteSpace: 'pre-wrap',
|
||||||
|
wordBreak: 'break-word'
|
||||||
|
}}>
|
||||||
|
{message.content}
|
||||||
|
</Paragraph>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default MessageBubble
|
||||||
139
frontend/src/components/ProviderCard.tsx
Normal file
139
frontend/src/components/ProviderCard.tsx
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
/**
|
||||||
|
* AI接口卡片组件
|
||||||
|
*/
|
||||||
|
import React from 'react'
|
||||||
|
import { Card, Tag, Switch, Typography, Space, Button, Tooltip } from 'antd'
|
||||||
|
import {
|
||||||
|
ApiOutlined,
|
||||||
|
EditOutlined,
|
||||||
|
DeleteOutlined,
|
||||||
|
CheckCircleOutlined,
|
||||||
|
CloudOutlined,
|
||||||
|
DesktopOutlined
|
||||||
|
} from '@ant-design/icons'
|
||||||
|
import type { AIProvider } from '../types'
|
||||||
|
|
||||||
|
const { Text } = Typography
|
||||||
|
|
||||||
|
// 提供商类型图标和颜色
|
||||||
|
const PROVIDER_CONFIG: Record<string, { color: string; label: string; local?: boolean }> = {
|
||||||
|
minimax: { color: '#1890ff', label: 'MiniMax' },
|
||||||
|
zhipu: { color: '#52c41a', label: '智谱AI' },
|
||||||
|
openrouter: { color: '#722ed1', label: 'OpenRouter' },
|
||||||
|
kimi: { color: '#eb2f96', label: 'Kimi' },
|
||||||
|
deepseek: { color: '#13c2c2', label: 'DeepSeek' },
|
||||||
|
gemini: { color: '#4285f4', label: 'Gemini' },
|
||||||
|
ollama: { color: '#fa8c16', label: 'Ollama', local: true },
|
||||||
|
llmstudio: { color: '#a0d911', label: 'LLM Studio', local: true }
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ProviderCardProps {
|
||||||
|
provider: AIProvider
|
||||||
|
onEdit?: (provider: AIProvider) => void
|
||||||
|
onDelete?: (provider: AIProvider) => void
|
||||||
|
onTest?: (provider: AIProvider) => void
|
||||||
|
onToggleEnabled?: (provider: AIProvider, enabled: boolean) => void
|
||||||
|
testing?: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
const ProviderCard: React.FC<ProviderCardProps> = ({
|
||||||
|
provider,
|
||||||
|
onEdit,
|
||||||
|
onDelete,
|
||||||
|
onTest,
|
||||||
|
onToggleEnabled,
|
||||||
|
testing = false
|
||||||
|
}) => {
|
||||||
|
const config = PROVIDER_CONFIG[provider.provider_type] || {
|
||||||
|
color: '#666',
|
||||||
|
label: provider.provider_type
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Card
|
||||||
|
hoverable
|
||||||
|
style={{
|
||||||
|
borderTop: `3px solid ${config.color}`,
|
||||||
|
opacity: provider.enabled ? 1 : 0.6
|
||||||
|
}}
|
||||||
|
actions={[
|
||||||
|
<Tooltip title="测试连接" key="test">
|
||||||
|
<Button
|
||||||
|
type="text"
|
||||||
|
icon={<CheckCircleOutlined />}
|
||||||
|
onClick={() => onTest?.(provider)}
|
||||||
|
loading={testing}
|
||||||
|
/>
|
||||||
|
</Tooltip>,
|
||||||
|
<Tooltip title="编辑" key="edit">
|
||||||
|
<Button
|
||||||
|
type="text"
|
||||||
|
icon={<EditOutlined />}
|
||||||
|
onClick={() => onEdit?.(provider)}
|
||||||
|
/>
|
||||||
|
</Tooltip>,
|
||||||
|
<Tooltip title="删除" key="delete">
|
||||||
|
<Button
|
||||||
|
type="text"
|
||||||
|
danger
|
||||||
|
icon={<DeleteOutlined />}
|
||||||
|
onClick={() => onDelete?.(provider)}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
]}
|
||||||
|
>
|
||||||
|
<Card.Meta
|
||||||
|
avatar={
|
||||||
|
<div style={{
|
||||||
|
width: 48,
|
||||||
|
height: 48,
|
||||||
|
borderRadius: 8,
|
||||||
|
background: config.color,
|
||||||
|
display: 'flex',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
color: '#fff',
|
||||||
|
fontSize: 24
|
||||||
|
}}>
|
||||||
|
{config.local ? <DesktopOutlined /> : <CloudOutlined />}
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
title={
|
||||||
|
<Space>
|
||||||
|
<span>{provider.name}</span>
|
||||||
|
<Switch
|
||||||
|
size="small"
|
||||||
|
checked={provider.enabled}
|
||||||
|
onChange={(checked) => onToggleEnabled?.(provider, checked)}
|
||||||
|
/>
|
||||||
|
</Space>
|
||||||
|
}
|
||||||
|
description={
|
||||||
|
<div>
|
||||||
|
<Space size={4} wrap>
|
||||||
|
<Tag color={config.color}>{config.label}</Tag>
|
||||||
|
{provider.use_proxy && <Tag color="orange">代理</Tag>}
|
||||||
|
{config.local && <Tag color="green">本地</Tag>}
|
||||||
|
</Space>
|
||||||
|
|
||||||
|
<div style={{ marginTop: 8 }}>
|
||||||
|
<Text type="secondary" style={{ fontSize: 12 }}>
|
||||||
|
模型: {provider.model}
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{provider.api_key_masked && (
|
||||||
|
<div style={{ marginTop: 4 }}>
|
||||||
|
<Text type="secondary" style={{ fontSize: 12 }}>
|
||||||
|
密钥: {provider.api_key_masked}
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</Card>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default ProviderCard
|
||||||
92
frontend/src/components/Sidebar.tsx
Normal file
92
frontend/src/components/Sidebar.tsx
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
/**
|
||||||
|
* 侧边栏导航组件
|
||||||
|
*/
|
||||||
|
import React from 'react'
|
||||||
|
import { useNavigate, useLocation } from 'react-router-dom'
|
||||||
|
import { Layout, Menu } from 'antd'
|
||||||
|
import {
|
||||||
|
DashboardOutlined,
|
||||||
|
ApiOutlined,
|
||||||
|
RobotOutlined,
|
||||||
|
MessageOutlined,
|
||||||
|
HistoryOutlined
|
||||||
|
} from '@ant-design/icons'
|
||||||
|
|
||||||
|
const { Sider } = Layout
|
||||||
|
|
||||||
|
const menuItems = [
|
||||||
|
{
|
||||||
|
key: '/dashboard',
|
||||||
|
icon: <DashboardOutlined />,
|
||||||
|
label: '控制台'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: '/providers',
|
||||||
|
icon: <ApiOutlined />,
|
||||||
|
label: 'AI接口配置'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: '/agents',
|
||||||
|
icon: <RobotOutlined />,
|
||||||
|
label: 'Agent管理'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: '/chatroom',
|
||||||
|
icon: <MessageOutlined />,
|
||||||
|
label: '聊天室'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: '/history',
|
||||||
|
icon: <HistoryOutlined />,
|
||||||
|
label: '讨论历史'
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
const Sidebar: React.FC = () => {
|
||||||
|
const navigate = useNavigate()
|
||||||
|
const location = useLocation()
|
||||||
|
|
||||||
|
// 获取当前选中的菜单项
|
||||||
|
const selectedKey = menuItems.find(
|
||||||
|
item => location.pathname.startsWith(item.key)
|
||||||
|
)?.key || '/dashboard'
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Sider
|
||||||
|
theme="dark"
|
||||||
|
width={200}
|
||||||
|
style={{
|
||||||
|
overflow: 'auto',
|
||||||
|
height: '100vh',
|
||||||
|
position: 'fixed',
|
||||||
|
left: 0,
|
||||||
|
top: 0,
|
||||||
|
bottom: 0
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<div style={{
|
||||||
|
height: 64,
|
||||||
|
display: 'flex',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
color: '#fff',
|
||||||
|
fontSize: 18,
|
||||||
|
fontWeight: 'bold',
|
||||||
|
borderBottom: '1px solid rgba(255,255,255,0.1)'
|
||||||
|
}}>
|
||||||
|
AI聊天室
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Menu
|
||||||
|
theme="dark"
|
||||||
|
mode="inline"
|
||||||
|
selectedKeys={[selectedKey]}
|
||||||
|
items={menuItems}
|
||||||
|
onClick={({ key }) => navigate(key)}
|
||||||
|
style={{ borderRight: 0 }}
|
||||||
|
/>
|
||||||
|
</Sider>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default Sidebar
|
||||||
74
frontend/src/components/TypingIndicator.tsx
Normal file
74
frontend/src/components/TypingIndicator.tsx
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
/**
|
||||||
|
* 输入状态指示器
|
||||||
|
*/
|
||||||
|
import React from 'react'
|
||||||
|
import { Avatar, Typography } from 'antd'
|
||||||
|
import { RobotOutlined } from '@ant-design/icons'
|
||||||
|
import type { Agent } from '../types'
|
||||||
|
|
||||||
|
const { Text } = Typography
|
||||||
|
|
||||||
|
interface TypingIndicatorProps {
|
||||||
|
agents: Agent[]
|
||||||
|
}
|
||||||
|
|
||||||
|
const TypingIndicator: React.FC<TypingIndicatorProps> = ({ agents }) => {
|
||||||
|
if (agents.length === 0) return null
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div style={{
|
||||||
|
display: 'flex',
|
||||||
|
alignItems: 'center',
|
||||||
|
gap: 12,
|
||||||
|
padding: '8px 12px',
|
||||||
|
marginTop: 8
|
||||||
|
}}>
|
||||||
|
<div style={{ display: 'flex' }}>
|
||||||
|
{agents.slice(0, 3).map((agent, index) => (
|
||||||
|
<Avatar
|
||||||
|
key={agent.agent_id}
|
||||||
|
size={32}
|
||||||
|
icon={<RobotOutlined />}
|
||||||
|
style={{
|
||||||
|
backgroundColor: agent.color,
|
||||||
|
marginLeft: index > 0 ? -8 : 0,
|
||||||
|
border: '2px solid #fff'
|
||||||
|
}}
|
||||||
|
src={agent.avatar}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div style={{ display: 'flex', alignItems: 'center', gap: 4 }}>
|
||||||
|
<Text type="secondary">
|
||||||
|
{agents.length === 1
|
||||||
|
? `${agents[0].name} 正在思考`
|
||||||
|
: `${agents.length} 位Agent正在思考`
|
||||||
|
}
|
||||||
|
</Text>
|
||||||
|
<span style={{ display: 'flex', gap: 2 }}>
|
||||||
|
<span className="typing-dot" style={{
|
||||||
|
width: 4,
|
||||||
|
height: 4,
|
||||||
|
borderRadius: '50%',
|
||||||
|
background: '#999'
|
||||||
|
}} />
|
||||||
|
<span className="typing-dot" style={{
|
||||||
|
width: 4,
|
||||||
|
height: 4,
|
||||||
|
borderRadius: '50%',
|
||||||
|
background: '#999'
|
||||||
|
}} />
|
||||||
|
<span className="typing-dot" style={{
|
||||||
|
width: 4,
|
||||||
|
height: 4,
|
||||||
|
borderRadius: '50%',
|
||||||
|
background: '#999'
|
||||||
|
}} />
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default TypingIndicator
|
||||||
73
frontend/src/index.css
Normal file
73
frontend/src/index.css
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
* {
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
box-sizing: border-box;
|
||||||
|
}
|
||||||
|
|
||||||
|
html, body, #root {
|
||||||
|
height: 100%;
|
||||||
|
width: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
|
||||||
|
background-color: #f0f2f5;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 自定义滚动条 */
|
||||||
|
::-webkit-scrollbar {
|
||||||
|
width: 6px;
|
||||||
|
height: 6px;
|
||||||
|
}
|
||||||
|
|
||||||
|
::-webkit-scrollbar-track {
|
||||||
|
background: #f1f1f1;
|
||||||
|
border-radius: 3px;
|
||||||
|
}
|
||||||
|
|
||||||
|
::-webkit-scrollbar-thumb {
|
||||||
|
background: #c1c1c1;
|
||||||
|
border-radius: 3px;
|
||||||
|
}
|
||||||
|
|
||||||
|
::-webkit-scrollbar-thumb:hover {
|
||||||
|
background: #a8a8a8;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 消息动画 */
|
||||||
|
@keyframes fadeIn {
|
||||||
|
from {
|
||||||
|
opacity: 0;
|
||||||
|
transform: translateY(10px);
|
||||||
|
}
|
||||||
|
to {
|
||||||
|
opacity: 1;
|
||||||
|
transform: translateY(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-fade-in {
|
||||||
|
animation: fadeIn 0.3s ease-out;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 打字动画 */
|
||||||
|
@keyframes typing {
|
||||||
|
0%, 60%, 100% {
|
||||||
|
opacity: 0.3;
|
||||||
|
}
|
||||||
|
30% {
|
||||||
|
opacity: 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.typing-dot {
|
||||||
|
animation: typing 1.4s infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
.typing-dot:nth-child(2) {
|
||||||
|
animation-delay: 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.typing-dot:nth-child(3) {
|
||||||
|
animation-delay: 0.4s;
|
||||||
|
}
|
||||||
17
frontend/src/main.tsx
Normal file
17
frontend/src/main.tsx
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
/**
|
||||||
|
* 应用入口
|
||||||
|
*/
|
||||||
|
import React from 'react'
|
||||||
|
import ReactDOM from 'react-dom/client'
|
||||||
|
import { ConfigProvider } from 'antd'
|
||||||
|
import zhCN from 'antd/locale/zh_CN'
|
||||||
|
import App from './App'
|
||||||
|
import './index.css'
|
||||||
|
|
||||||
|
ReactDOM.createRoot(document.getElementById('root')!).render(
|
||||||
|
<React.StrictMode>
|
||||||
|
<ConfigProvider locale={zhCN}>
|
||||||
|
<App />
|
||||||
|
</ConfigProvider>
|
||||||
|
</React.StrictMode>
|
||||||
|
)
|
||||||
471
frontend/src/pages/AgentManagement.tsx
Normal file
471
frontend/src/pages/AgentManagement.tsx
Normal file
@@ -0,0 +1,471 @@
|
|||||||
|
/**
|
||||||
|
* Agent管理页面
|
||||||
|
*/
|
||||||
|
import React, { useEffect, useState } from 'react'
|
||||||
|
import {
|
||||||
|
Row, Col, Button, Modal, Form, Input, Select, Switch, Slider,
|
||||||
|
Typography, message, Spin, Empty, Tabs, Card, Space, Tooltip
|
||||||
|
} from 'antd'
|
||||||
|
import { PlusOutlined, RobotOutlined, BulbOutlined, LoadingOutlined } from '@ant-design/icons'
|
||||||
|
import { useAgentStore } from '../stores/agentStore'
|
||||||
|
import { useProviderStore } from '../stores/providerStore'
|
||||||
|
import AgentCard from '../components/AgentCard'
|
||||||
|
import type { Agent } from '../types'
|
||||||
|
|
||||||
|
const { Title, Paragraph, Text } = Typography
|
||||||
|
const { TextArea } = Input
|
||||||
|
|
||||||
|
const AgentManagement: React.FC = () => {
|
||||||
|
const {
|
||||||
|
agents, templates, loading,
|
||||||
|
fetchAgents, fetchTemplates,
|
||||||
|
createAgent, updateAgent, deleteAgent,
|
||||||
|
testAgent, duplicateAgent, createFromTemplate, generatePrompt
|
||||||
|
} = useAgentStore()
|
||||||
|
const { providers, fetchProviders } = useProviderStore()
|
||||||
|
|
||||||
|
const [modalVisible, setModalVisible] = useState(false)
|
||||||
|
const [editingAgent, setEditingAgent] = useState<Agent | null>(null)
|
||||||
|
const [testModalVisible, setTestModalVisible] = useState(false)
|
||||||
|
const [testingAgent, setTestingAgent] = useState<Agent | null>(null)
|
||||||
|
const [testResult, setTestResult] = useState<string>('')
|
||||||
|
const [testLoading, setTestLoading] = useState(false)
|
||||||
|
const [generatingPrompt, setGeneratingPrompt] = useState(false)
|
||||||
|
const [form] = Form.useForm()
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
fetchAgents()
|
||||||
|
fetchTemplates()
|
||||||
|
fetchProviders()
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
const enabledProviders = providers.filter(p => p.enabled)
|
||||||
|
|
||||||
|
// 打开创建/编辑弹窗
|
||||||
|
const openModal = (agent?: Agent) => {
|
||||||
|
setEditingAgent(agent || null)
|
||||||
|
if (agent) {
|
||||||
|
form.setFieldsValue({
|
||||||
|
name: agent.name,
|
||||||
|
role: agent.role,
|
||||||
|
system_prompt: agent.system_prompt,
|
||||||
|
provider_id: agent.provider_id,
|
||||||
|
temperature: agent.temperature,
|
||||||
|
max_tokens: agent.max_tokens,
|
||||||
|
memory_enabled: agent.capabilities.memory_enabled,
|
||||||
|
multimodal: agent.capabilities.multimodal,
|
||||||
|
speak_threshold: agent.behavior.speak_threshold,
|
||||||
|
max_speak_per_round: agent.behavior.max_speak_per_round,
|
||||||
|
color: agent.color
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
form.resetFields()
|
||||||
|
form.setFieldsValue({
|
||||||
|
temperature: 0.7,
|
||||||
|
max_tokens: 2000,
|
||||||
|
speak_threshold: 0.5,
|
||||||
|
max_speak_per_round: 2,
|
||||||
|
color: '#1890ff'
|
||||||
|
})
|
||||||
|
}
|
||||||
|
setModalVisible(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保存
|
||||||
|
const handleSave = async () => {
|
||||||
|
try {
|
||||||
|
const values = await form.validateFields()
|
||||||
|
|
||||||
|
const data = {
|
||||||
|
name: values.name,
|
||||||
|
role: values.role,
|
||||||
|
system_prompt: values.system_prompt,
|
||||||
|
provider_id: values.provider_id,
|
||||||
|
temperature: values.temperature,
|
||||||
|
max_tokens: values.max_tokens,
|
||||||
|
capabilities: {
|
||||||
|
memory_enabled: values.memory_enabled || false,
|
||||||
|
multimodal: values.multimodal || false,
|
||||||
|
mcp_tools: [],
|
||||||
|
skills: []
|
||||||
|
},
|
||||||
|
behavior: {
|
||||||
|
speak_threshold: values.speak_threshold,
|
||||||
|
max_speak_per_round: values.max_speak_per_round,
|
||||||
|
speak_style: 'balanced'
|
||||||
|
},
|
||||||
|
color: values.color
|
||||||
|
}
|
||||||
|
|
||||||
|
if (editingAgent) {
|
||||||
|
await updateAgent(editingAgent.agent_id, data)
|
||||||
|
message.success('更新成功')
|
||||||
|
} else {
|
||||||
|
await createAgent(data)
|
||||||
|
message.success('创建成功')
|
||||||
|
}
|
||||||
|
|
||||||
|
setModalVisible(false)
|
||||||
|
} catch (e) {
|
||||||
|
message.error((e as Error).message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除
|
||||||
|
const handleDelete = (agent: Agent) => {
|
||||||
|
Modal.confirm({
|
||||||
|
title: '确认删除',
|
||||||
|
content: `确定要删除Agent "${agent.name}" 吗?`,
|
||||||
|
okText: '删除',
|
||||||
|
okType: 'danger',
|
||||||
|
onOk: async () => {
|
||||||
|
try {
|
||||||
|
await deleteAgent(agent.agent_id)
|
||||||
|
message.success('删除成功')
|
||||||
|
} catch (e) {
|
||||||
|
message.error((e as Error).message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试
|
||||||
|
const handleTest = (agent: Agent) => {
|
||||||
|
setTestingAgent(agent)
|
||||||
|
setTestResult('')
|
||||||
|
setTestModalVisible(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
const runTest = async (testMessage: string) => {
|
||||||
|
if (!testingAgent) return
|
||||||
|
setTestLoading(true)
|
||||||
|
try {
|
||||||
|
const result = await testAgent(testingAgent.agent_id, testMessage)
|
||||||
|
if (result.success) {
|
||||||
|
setTestResult(result.response || '无响应')
|
||||||
|
} else {
|
||||||
|
setTestResult(`错误: ${result.message}`)
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
setTestResult(`错误: ${(e as Error).message}`)
|
||||||
|
} finally {
|
||||||
|
setTestLoading(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 复制
|
||||||
|
const handleDuplicate = async (agent: Agent) => {
|
||||||
|
try {
|
||||||
|
await duplicateAgent(agent.agent_id)
|
||||||
|
message.success('复制成功')
|
||||||
|
} catch (e) {
|
||||||
|
message.error((e as Error).message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 切换启用
|
||||||
|
const handleToggle = async (agent: Agent, enabled: boolean) => {
|
||||||
|
try {
|
||||||
|
await updateAgent(agent.agent_id, { enabled })
|
||||||
|
message.success(enabled ? '已启用' : '已禁用')
|
||||||
|
} catch (e) {
|
||||||
|
message.error((e as Error).message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从模板创建
|
||||||
|
const handleCreateFromTemplate = async (templateId: string) => {
|
||||||
|
if (enabledProviders.length === 0) {
|
||||||
|
message.error('请先配置AI接口')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
Modal.confirm({
|
||||||
|
title: '选择AI接口',
|
||||||
|
content: (
|
||||||
|
<Select
|
||||||
|
id="template-provider-select"
|
||||||
|
style={{ width: '100%', marginTop: 16 }}
|
||||||
|
placeholder="选择AI接口"
|
||||||
|
options={enabledProviders.map(p => ({
|
||||||
|
value: p.provider_id,
|
||||||
|
label: `${p.name} (${p.model})`
|
||||||
|
}))}
|
||||||
|
/>
|
||||||
|
),
|
||||||
|
onOk: async () => {
|
||||||
|
const select = document.getElementById('template-provider-select') as HTMLSelectElement
|
||||||
|
const providerId = select?.value
|
||||||
|
if (providerId) {
|
||||||
|
try {
|
||||||
|
await createFromTemplate(templateId, providerId)
|
||||||
|
message.success('创建成功')
|
||||||
|
} catch (e) {
|
||||||
|
message.error((e as Error).message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AI生成系统提示词
|
||||||
|
const handleGeneratePrompt = async () => {
|
||||||
|
const name = form.getFieldValue('name')
|
||||||
|
const role = form.getFieldValue('role')
|
||||||
|
const providerId = form.getFieldValue('provider_id')
|
||||||
|
|
||||||
|
if (!name || !role) {
|
||||||
|
message.warning('请先填写Agent名称和角色')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!providerId) {
|
||||||
|
message.warning('请先选择AI接口')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
setGeneratingPrompt(true)
|
||||||
|
try {
|
||||||
|
const result = await generatePrompt({
|
||||||
|
provider_id: providerId,
|
||||||
|
name,
|
||||||
|
role,
|
||||||
|
description: form.getFieldValue('description')
|
||||||
|
})
|
||||||
|
|
||||||
|
if (result.success && result.prompt) {
|
||||||
|
form.setFieldValue('system_prompt', result.prompt)
|
||||||
|
message.success('提示词生成成功')
|
||||||
|
} else {
|
||||||
|
message.error(result.message || '生成失败')
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
message.error(`生成失败: ${(e as Error).message}`)
|
||||||
|
} finally {
|
||||||
|
setGeneratingPrompt(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div style={{ marginLeft: 200 }}>
|
||||||
|
<div style={{ display: 'flex', justifyContent: 'space-between', marginBottom: 24 }}>
|
||||||
|
<Title level={2}>Agent管理</Title>
|
||||||
|
<Button type="primary" icon={<PlusOutlined />} onClick={() => openModal()}>
|
||||||
|
创建Agent
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Tabs defaultActiveKey="agents" items={[
|
||||||
|
{
|
||||||
|
key: 'agents',
|
||||||
|
label: 'Agent列表',
|
||||||
|
children: (
|
||||||
|
<Spin spinning={loading}>
|
||||||
|
{agents.length > 0 ? (
|
||||||
|
<Row gutter={[16, 16]}>
|
||||||
|
{agents.map((agent) => (
|
||||||
|
<Col key={agent.agent_id} xs={24} sm={12} lg={8} xl={6}>
|
||||||
|
<AgentCard
|
||||||
|
agent={agent}
|
||||||
|
onEdit={openModal}
|
||||||
|
onDelete={handleDelete}
|
||||||
|
onDuplicate={handleDuplicate}
|
||||||
|
onTest={handleTest}
|
||||||
|
onToggleEnabled={handleToggle}
|
||||||
|
/>
|
||||||
|
</Col>
|
||||||
|
))}
|
||||||
|
</Row>
|
||||||
|
) : (
|
||||||
|
<Empty description="暂无Agent" />
|
||||||
|
)}
|
||||||
|
</Spin>
|
||||||
|
)
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: 'templates',
|
||||||
|
label: '预设模板',
|
||||||
|
children: (
|
||||||
|
<Row gutter={[16, 16]}>
|
||||||
|
{templates.map((template) => (
|
||||||
|
<Col key={template.template_id} xs={24} sm={12} lg={8} xl={6}>
|
||||||
|
<Card
|
||||||
|
hoverable
|
||||||
|
style={{ borderTop: `3px solid ${template.color}` }}
|
||||||
|
onClick={() => handleCreateFromTemplate(template.template_id)}
|
||||||
|
>
|
||||||
|
<Card.Meta
|
||||||
|
avatar={
|
||||||
|
<div style={{
|
||||||
|
width: 48,
|
||||||
|
height: 48,
|
||||||
|
borderRadius: '50%',
|
||||||
|
background: template.color,
|
||||||
|
display: 'flex',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
color: '#fff'
|
||||||
|
}}>
|
||||||
|
<RobotOutlined style={{ fontSize: 24 }} />
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
title={template.name}
|
||||||
|
description={
|
||||||
|
<Paragraph ellipsis={{ rows: 2 }} style={{ marginBottom: 0 }}>
|
||||||
|
{template.role}
|
||||||
|
</Paragraph>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</Card>
|
||||||
|
</Col>
|
||||||
|
))}
|
||||||
|
</Row>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
]} />
|
||||||
|
|
||||||
|
{/* 创建/编辑弹窗 */}
|
||||||
|
<Modal
|
||||||
|
title={editingAgent ? '编辑Agent' : '创建Agent'}
|
||||||
|
open={modalVisible}
|
||||||
|
onOk={handleSave}
|
||||||
|
onCancel={() => setModalVisible(false)}
|
||||||
|
width={700}
|
||||||
|
>
|
||||||
|
<Form form={form} layout="vertical">
|
||||||
|
<Row gutter={16}>
|
||||||
|
<Col span={12}>
|
||||||
|
<Form.Item
|
||||||
|
name="name"
|
||||||
|
label="名称"
|
||||||
|
rules={[{ required: true, message: '请输入名称' }]}
|
||||||
|
>
|
||||||
|
<Input placeholder="Agent名称" />
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
<Col span={12}>
|
||||||
|
<Form.Item
|
||||||
|
name="role"
|
||||||
|
label="角色"
|
||||||
|
rules={[{ required: true, message: '请输入角色' }]}
|
||||||
|
>
|
||||||
|
<Input placeholder="如: 产品经理、开发工程师" />
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
|
|
||||||
|
<Form.Item
|
||||||
|
name="system_prompt"
|
||||||
|
label={
|
||||||
|
<Space>
|
||||||
|
系统提示词
|
||||||
|
<Tooltip title="根据名称和角色自动生成">
|
||||||
|
<Button
|
||||||
|
type="text"
|
||||||
|
size="small"
|
||||||
|
icon={generatingPrompt ? <LoadingOutlined /> : <BulbOutlined />}
|
||||||
|
onClick={handleGeneratePrompt}
|
||||||
|
disabled={generatingPrompt}
|
||||||
|
>
|
||||||
|
{generatingPrompt ? '生成中...' : 'AI生成'}
|
||||||
|
</Button>
|
||||||
|
</Tooltip>
|
||||||
|
</Space>
|
||||||
|
}
|
||||||
|
rules={[{ required: true, message: '请输入系统提示词' }]}
|
||||||
|
>
|
||||||
|
<TextArea rows={4} placeholder="定义Agent的行为和专业领域,或点击右上角「AI生成」自动编写" />
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item
|
||||||
|
name="provider_id"
|
||||||
|
label="AI接口"
|
||||||
|
rules={[{ required: true, message: '请选择AI接口' }]}
|
||||||
|
>
|
||||||
|
<Select
|
||||||
|
placeholder="选择AI接口"
|
||||||
|
options={enabledProviders.map(p => ({
|
||||||
|
value: p.provider_id,
|
||||||
|
label: `${p.name} (${p.model})`
|
||||||
|
}))}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Row gutter={16}>
|
||||||
|
<Col span={12}>
|
||||||
|
<Form.Item name="temperature" label="温度 (创造性)">
|
||||||
|
<Slider min={0} max={2} step={0.1} marks={{ 0: '精确', 1: '平衡', 2: '创造' }} />
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
<Col span={12}>
|
||||||
|
<Form.Item name="max_tokens" label="最大Token数">
|
||||||
|
<Slider min={500} max={8000} step={100} />
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
|
|
||||||
|
<Row gutter={16}>
|
||||||
|
<Col span={8}>
|
||||||
|
<Form.Item name="memory_enabled" label="启用记忆" valuePropName="checked">
|
||||||
|
<Switch />
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
<Col span={8}>
|
||||||
|
<Form.Item name="multimodal" label="多模态支持" valuePropName="checked">
|
||||||
|
<Switch />
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
<Col span={8}>
|
||||||
|
<Form.Item name="color" label="代表颜色">
|
||||||
|
<Input type="color" />
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
|
|
||||||
|
<Row gutter={16}>
|
||||||
|
<Col span={12}>
|
||||||
|
<Form.Item name="speak_threshold" label="发言倾向">
|
||||||
|
<Slider min={0} max={1} step={0.1} marks={{ 0: '谨慎', 0.5: '适中', 1: '积极' }} />
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
<Col span={12}>
|
||||||
|
<Form.Item name="max_speak_per_round" label="每轮最大发言次数">
|
||||||
|
<Slider min={1} max={5} step={1} />
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
|
</Form>
|
||||||
|
</Modal>
|
||||||
|
|
||||||
|
{/* 测试弹窗 */}
|
||||||
|
<Modal
|
||||||
|
title={`测试Agent: ${testingAgent?.name}`}
|
||||||
|
open={testModalVisible}
|
||||||
|
onCancel={() => setTestModalVisible(false)}
|
||||||
|
footer={null}
|
||||||
|
width={600}
|
||||||
|
>
|
||||||
|
<Form onFinish={(v) => runTest(v.message)}>
|
||||||
|
<Form.Item name="message" initialValue="你好,请简单介绍一下你自己。">
|
||||||
|
<TextArea rows={2} placeholder="输入测试消息" />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item>
|
||||||
|
<Button type="primary" htmlType="submit" loading={testLoading}>
|
||||||
|
发送测试
|
||||||
|
</Button>
|
||||||
|
</Form.Item>
|
||||||
|
</Form>
|
||||||
|
|
||||||
|
{testResult && (
|
||||||
|
<Card title="响应" style={{ marginTop: 16 }}>
|
||||||
|
<Paragraph style={{ whiteSpace: 'pre-wrap' }}>
|
||||||
|
{testResult}
|
||||||
|
</Paragraph>
|
||||||
|
</Card>
|
||||||
|
)}
|
||||||
|
</Modal>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default AgentManagement
|
||||||
419
frontend/src/pages/ChatRoom.tsx
Normal file
419
frontend/src/pages/ChatRoom.tsx
Normal file
@@ -0,0 +1,419 @@
|
|||||||
|
/**
|
||||||
|
* 聊天室页面
|
||||||
|
*/
|
||||||
|
import React, { useEffect, useState, useRef } from 'react'
|
||||||
|
import { useParams, useNavigate } from 'react-router-dom'
|
||||||
|
import {
|
||||||
|
Layout, Card, List, Button, Input, Modal, Form, Select,
|
||||||
|
Typography, Tag, Space, Empty, Spin, message, Avatar, Progress
|
||||||
|
} from 'antd'
|
||||||
|
import {
|
||||||
|
PlusOutlined, PlayCircleOutlined, PauseCircleOutlined,
|
||||||
|
StopOutlined, RobotOutlined, DeleteOutlined
|
||||||
|
} from '@ant-design/icons'
|
||||||
|
import { useChatroomStore } from '../stores/chatroomStore'
|
||||||
|
import { useAgentStore } from '../stores/agentStore'
|
||||||
|
import MessageBubble from '../components/MessageBubble'
|
||||||
|
import TypingIndicator from '../components/TypingIndicator'
|
||||||
|
|
||||||
|
const { Sider, Content } = Layout
|
||||||
|
const { Title, Text, Paragraph } = Typography
|
||||||
|
const { TextArea } = Input
|
||||||
|
|
||||||
|
const ChatRoom: React.FC = () => {
|
||||||
|
const { roomId } = useParams()
|
||||||
|
const navigate = useNavigate()
|
||||||
|
const messagesEndRef = useRef<HTMLDivElement>(null)
|
||||||
|
|
||||||
|
const {
|
||||||
|
chatrooms, currentRoom, messages, typingAgents, loading,
|
||||||
|
fetchChatrooms, fetchChatroom, createChatroom, deleteChatroom,
|
||||||
|
startDiscussion, pauseDiscussion, resumeDiscussion, stopDiscussion,
|
||||||
|
connectWebSocket, disconnectWebSocket, fetchMessages
|
||||||
|
} = useChatroomStore()
|
||||||
|
|
||||||
|
const { agents, fetchAgents, getAgentById } = useAgentStore()
|
||||||
|
|
||||||
|
const [createModalVisible, setCreateModalVisible] = useState(false)
|
||||||
|
const [objectiveModalVisible, setObjectiveModalVisible] = useState(false)
|
||||||
|
const [objective, setObjective] = useState('')
|
||||||
|
const [form] = Form.useForm()
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
fetchChatrooms()
|
||||||
|
fetchAgents()
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (roomId) {
|
||||||
|
fetchChatroom(roomId)
|
||||||
|
fetchMessages(roomId)
|
||||||
|
connectWebSocket(roomId)
|
||||||
|
}
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
disconnectWebSocket()
|
||||||
|
}
|
||||||
|
}, [roomId])
|
||||||
|
|
||||||
|
// 自动滚动到底部
|
||||||
|
useEffect(() => {
|
||||||
|
messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' })
|
||||||
|
}, [messages])
|
||||||
|
|
||||||
|
// 获取正在输入的Agent信息
|
||||||
|
const typingAgentsList = Array.from(typingAgents)
|
||||||
|
.map(id => getAgentById(id))
|
||||||
|
.filter(Boolean) as typeof agents
|
||||||
|
|
||||||
|
// 创建聊天室
|
||||||
|
const handleCreateRoom = async () => {
|
||||||
|
try {
|
||||||
|
const values = await form.validateFields()
|
||||||
|
const room = await createChatroom({
|
||||||
|
name: values.name,
|
||||||
|
description: values.description,
|
||||||
|
agents: values.agents,
|
||||||
|
moderator_agent_id: values.moderator_agent_id
|
||||||
|
})
|
||||||
|
message.success('创建成功')
|
||||||
|
setCreateModalVisible(false)
|
||||||
|
navigate(`/chatroom/${room.room_id}`)
|
||||||
|
} catch (e) {
|
||||||
|
message.error((e as Error).message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除聊天室
|
||||||
|
const handleDeleteRoom = (roomId: string, roomName: string) => {
|
||||||
|
Modal.confirm({
|
||||||
|
title: '确认删除',
|
||||||
|
content: `确定要删除聊天室 "${roomName}" 吗?`,
|
||||||
|
okText: '删除',
|
||||||
|
okType: 'danger',
|
||||||
|
onOk: async () => {
|
||||||
|
try {
|
||||||
|
await deleteChatroom(roomId)
|
||||||
|
message.success('删除成功')
|
||||||
|
if (currentRoom?.room_id === roomId) {
|
||||||
|
navigate('/chatroom')
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
message.error((e as Error).message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启动讨论
|
||||||
|
const handleStartDiscussion = async () => {
|
||||||
|
if (!currentRoom || !objective.trim()) return
|
||||||
|
try {
|
||||||
|
await startDiscussion(currentRoom.room_id, objective)
|
||||||
|
message.success('讨论已启动')
|
||||||
|
setObjectiveModalVisible(false)
|
||||||
|
} catch (e) {
|
||||||
|
message.error((e as Error).message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 暂停/恢复/停止
|
||||||
|
const handlePause = async () => {
|
||||||
|
if (!currentRoom) return
|
||||||
|
try {
|
||||||
|
await pauseDiscussion(currentRoom.room_id)
|
||||||
|
message.info('讨论已暂停')
|
||||||
|
} catch (e) {
|
||||||
|
message.error((e as Error).message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleResume = async () => {
|
||||||
|
if (!currentRoom) return
|
||||||
|
try {
|
||||||
|
await resumeDiscussion(currentRoom.room_id)
|
||||||
|
message.success('讨论已恢复')
|
||||||
|
} catch (e) {
|
||||||
|
message.error((e as Error).message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleStop = async () => {
|
||||||
|
if (!currentRoom) return
|
||||||
|
Modal.confirm({
|
||||||
|
title: '确认停止',
|
||||||
|
content: '确定要停止当前讨论吗?',
|
||||||
|
onOk: async () => {
|
||||||
|
try {
|
||||||
|
await stopDiscussion(currentRoom.room_id)
|
||||||
|
message.info('正在停止讨论...')
|
||||||
|
} catch (e) {
|
||||||
|
message.error((e as Error).message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const enabledAgents = agents.filter(a => a.enabled)
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Layout style={{ marginLeft: 200, height: 'calc(100vh - 48px)' }}>
|
||||||
|
{/* 聊天室列表 */}
|
||||||
|
<Sider width={280} theme="light" style={{
|
||||||
|
borderRight: '1px solid #f0f0f0',
|
||||||
|
overflow: 'auto'
|
||||||
|
}}>
|
||||||
|
<div style={{ padding: 16 }}>
|
||||||
|
<Button
|
||||||
|
type="primary"
|
||||||
|
icon={<PlusOutlined />}
|
||||||
|
block
|
||||||
|
onClick={() => setCreateModalVisible(true)}
|
||||||
|
>
|
||||||
|
创建聊天室
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<List
|
||||||
|
dataSource={chatrooms}
|
||||||
|
renderItem={(room) => (
|
||||||
|
<List.Item
|
||||||
|
style={{
|
||||||
|
padding: '12px 16px',
|
||||||
|
cursor: 'pointer',
|
||||||
|
background: room.room_id === roomId ? '#e6f7ff' : 'transparent'
|
||||||
|
}}
|
||||||
|
onClick={() => navigate(`/chatroom/${room.room_id}`)}
|
||||||
|
actions={[
|
||||||
|
<Button
|
||||||
|
type="text"
|
||||||
|
danger
|
||||||
|
size="small"
|
||||||
|
icon={<DeleteOutlined />}
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation()
|
||||||
|
handleDeleteRoom(room.room_id, room.name)
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
]}
|
||||||
|
>
|
||||||
|
<List.Item.Meta
|
||||||
|
title={
|
||||||
|
<Space>
|
||||||
|
<span>{room.name}</span>
|
||||||
|
<Tag color={
|
||||||
|
room.status === 'active' ? 'green' :
|
||||||
|
room.status === 'paused' ? 'orange' :
|
||||||
|
room.status === 'completed' ? 'blue' : 'default'
|
||||||
|
} style={{ marginLeft: 4 }}>
|
||||||
|
{room.status === 'active' ? '进行中' :
|
||||||
|
room.status === 'paused' ? '已暂停' :
|
||||||
|
room.status === 'completed' ? '已完成' : '空闲'}
|
||||||
|
</Tag>
|
||||||
|
</Space>
|
||||||
|
}
|
||||||
|
description={
|
||||||
|
<Text type="secondary" ellipsis style={{ fontSize: 12 }}>
|
||||||
|
{room.agents.length} 个Agent
|
||||||
|
</Text>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</List.Item>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
</Sider>
|
||||||
|
|
||||||
|
{/* 聊天内容区 */}
|
||||||
|
<Content style={{ display: 'flex', flexDirection: 'column', background: '#fff' }}>
|
||||||
|
{currentRoom ? (
|
||||||
|
<>
|
||||||
|
{/* 头部 */}
|
||||||
|
<div style={{
|
||||||
|
padding: '12px 24px',
|
||||||
|
borderBottom: '1px solid #f0f0f0',
|
||||||
|
display: 'flex',
|
||||||
|
justifyContent: 'space-between',
|
||||||
|
alignItems: 'center'
|
||||||
|
}}>
|
||||||
|
<div>
|
||||||
|
<Title level={4} style={{ margin: 0 }}>{currentRoom.name}</Title>
|
||||||
|
{currentRoom.objective && (
|
||||||
|
<Text type="secondary">{currentRoom.objective}</Text>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Space>
|
||||||
|
{currentRoom.status === 'active' && (
|
||||||
|
<Tag color="green">第 {currentRoom.current_round} 轮</Tag>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{currentRoom.status === 'idle' && (
|
||||||
|
<Button
|
||||||
|
type="primary"
|
||||||
|
icon={<PlayCircleOutlined />}
|
||||||
|
onClick={() => setObjectiveModalVisible(true)}
|
||||||
|
>
|
||||||
|
开始讨论
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{currentRoom.status === 'active' && (
|
||||||
|
<>
|
||||||
|
<Button icon={<PauseCircleOutlined />} onClick={handlePause}>
|
||||||
|
暂停
|
||||||
|
</Button>
|
||||||
|
<Button danger icon={<StopOutlined />} onClick={handleStop}>
|
||||||
|
停止
|
||||||
|
</Button>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{currentRoom.status === 'paused' && (
|
||||||
|
<>
|
||||||
|
<Button type="primary" icon={<PlayCircleOutlined />} onClick={handleResume}>
|
||||||
|
继续
|
||||||
|
</Button>
|
||||||
|
<Button danger icon={<StopOutlined />} onClick={handleStop}>
|
||||||
|
停止
|
||||||
|
</Button>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{currentRoom.status === 'completed' && (
|
||||||
|
<Button
|
||||||
|
type="primary"
|
||||||
|
icon={<PlayCircleOutlined />}
|
||||||
|
onClick={() => setObjectiveModalVisible(true)}
|
||||||
|
>
|
||||||
|
新讨论
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
</Space>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Agent列表 */}
|
||||||
|
<div style={{
|
||||||
|
padding: '8px 24px',
|
||||||
|
borderBottom: '1px solid #f0f0f0',
|
||||||
|
display: 'flex',
|
||||||
|
gap: 8,
|
||||||
|
flexWrap: 'wrap'
|
||||||
|
}}>
|
||||||
|
{currentRoom.agents.map(agentId => {
|
||||||
|
const agent = getAgentById(agentId)
|
||||||
|
return agent ? (
|
||||||
|
<Tag
|
||||||
|
key={agentId}
|
||||||
|
color={agent.color}
|
||||||
|
icon={<Avatar size={16} style={{ background: agent.color }}><RobotOutlined /></Avatar>}
|
||||||
|
>
|
||||||
|
{agent.name}
|
||||||
|
</Tag>
|
||||||
|
) : null
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* 消息区 */}
|
||||||
|
<div style={{
|
||||||
|
flex: 1,
|
||||||
|
overflow: 'auto',
|
||||||
|
padding: '16px 24px',
|
||||||
|
background: '#fafafa'
|
||||||
|
}}>
|
||||||
|
{messages.length > 0 ? (
|
||||||
|
<>
|
||||||
|
{messages.map(msg => (
|
||||||
|
<MessageBubble
|
||||||
|
key={msg.message_id}
|
||||||
|
message={msg}
|
||||||
|
agent={msg.agent_id ? getAgentById(msg.agent_id) : undefined}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
<TypingIndicator agents={typingAgentsList} />
|
||||||
|
<div ref={messagesEndRef} />
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<Empty description="暂无消息" />
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<div style={{
|
||||||
|
flex: 1,
|
||||||
|
display: 'flex',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center'
|
||||||
|
}}>
|
||||||
|
<Empty description="选择或创建一个聊天室开始讨论" />
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</Content>
|
||||||
|
|
||||||
|
{/* 创建聊天室弹窗 */}
|
||||||
|
<Modal
|
||||||
|
title="创建聊天室"
|
||||||
|
open={createModalVisible}
|
||||||
|
onOk={handleCreateRoom}
|
||||||
|
onCancel={() => setCreateModalVisible(false)}
|
||||||
|
width={600}
|
||||||
|
>
|
||||||
|
<Form form={form} layout="vertical">
|
||||||
|
<Form.Item
|
||||||
|
name="name"
|
||||||
|
label="名称"
|
||||||
|
rules={[{ required: true, message: '请输入名称' }]}
|
||||||
|
>
|
||||||
|
<Input placeholder="聊天室名称" />
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item name="description" label="描述">
|
||||||
|
<TextArea rows={2} placeholder="聊天室描述" />
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item
|
||||||
|
name="agents"
|
||||||
|
label="参与Agent"
|
||||||
|
rules={[{ required: true, message: '请选择至少一个Agent' }]}
|
||||||
|
>
|
||||||
|
<Select
|
||||||
|
mode="multiple"
|
||||||
|
placeholder="选择参与讨论的Agent"
|
||||||
|
options={enabledAgents.map(a => ({
|
||||||
|
value: a.agent_id,
|
||||||
|
label: a.name
|
||||||
|
}))}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item name="moderator_agent_id" label="主持人Agent">
|
||||||
|
<Select
|
||||||
|
placeholder="选择主持人(用于共识判断)"
|
||||||
|
allowClear
|
||||||
|
options={enabledAgents.map(a => ({
|
||||||
|
value: a.agent_id,
|
||||||
|
label: a.name
|
||||||
|
}))}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
</Form>
|
||||||
|
</Modal>
|
||||||
|
|
||||||
|
{/* 设置讨论目标弹窗 */}
|
||||||
|
<Modal
|
||||||
|
title="设置讨论目标"
|
||||||
|
open={objectiveModalVisible}
|
||||||
|
onOk={handleStartDiscussion}
|
||||||
|
onCancel={() => setObjectiveModalVisible(false)}
|
||||||
|
>
|
||||||
|
<TextArea
|
||||||
|
rows={4}
|
||||||
|
placeholder="输入本次讨论的目标,例如:设计一个用户友好的登录系统"
|
||||||
|
value={objective}
|
||||||
|
onChange={(e) => setObjective(e.target.value)}
|
||||||
|
/>
|
||||||
|
</Modal>
|
||||||
|
</Layout>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default ChatRoom
|
||||||
199
frontend/src/pages/Dashboard.tsx
Normal file
199
frontend/src/pages/Dashboard.tsx
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
/**
|
||||||
|
* 控制台页面
|
||||||
|
* 显示系统概览
|
||||||
|
*/
|
||||||
|
import React, { useEffect } from 'react'
|
||||||
|
import { Row, Col, Card, Statistic, List, Typography, Tag, Button, Empty } from 'antd'
|
||||||
|
import {
|
||||||
|
ApiOutlined,
|
||||||
|
RobotOutlined,
|
||||||
|
MessageOutlined,
|
||||||
|
CheckCircleOutlined,
|
||||||
|
ArrowRightOutlined
|
||||||
|
} from '@ant-design/icons'
|
||||||
|
import { useNavigate } from 'react-router-dom'
|
||||||
|
import { useProviderStore } from '../stores/providerStore'
|
||||||
|
import { useAgentStore } from '../stores/agentStore'
|
||||||
|
import { useChatroomStore } from '../stores/chatroomStore'
|
||||||
|
|
||||||
|
const { Title, Text } = Typography
|
||||||
|
|
||||||
|
const Dashboard: React.FC = () => {
|
||||||
|
const navigate = useNavigate()
|
||||||
|
const { providers, fetchProviders } = useProviderStore()
|
||||||
|
const { agents, fetchAgents } = useAgentStore()
|
||||||
|
const { chatrooms, fetchChatrooms } = useChatroomStore()
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
fetchProviders()
|
||||||
|
fetchAgents()
|
||||||
|
fetchChatrooms()
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
const enabledProviders = providers.filter(p => p.enabled)
|
||||||
|
const enabledAgents = agents.filter(a => a.enabled)
|
||||||
|
const activeRooms = chatrooms.filter(r => r.status === 'active')
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div style={{ marginLeft: 200 }}>
|
||||||
|
<Title level={2}>控制台</Title>
|
||||||
|
|
||||||
|
{/* 统计卡片 */}
|
||||||
|
<Row gutter={16} style={{ marginBottom: 24 }}>
|
||||||
|
<Col span={6}>
|
||||||
|
<Card hoverable onClick={() => navigate('/providers')}>
|
||||||
|
<Statistic
|
||||||
|
title="AI接口"
|
||||||
|
value={enabledProviders.length}
|
||||||
|
suffix={`/ ${providers.length}`}
|
||||||
|
prefix={<ApiOutlined />}
|
||||||
|
valueStyle={{ color: '#1890ff' }}
|
||||||
|
/>
|
||||||
|
</Card>
|
||||||
|
</Col>
|
||||||
|
<Col span={6}>
|
||||||
|
<Card hoverable onClick={() => navigate('/agents')}>
|
||||||
|
<Statistic
|
||||||
|
title="Agent"
|
||||||
|
value={enabledAgents.length}
|
||||||
|
suffix={`/ ${agents.length}`}
|
||||||
|
prefix={<RobotOutlined />}
|
||||||
|
valueStyle={{ color: '#52c41a' }}
|
||||||
|
/>
|
||||||
|
</Card>
|
||||||
|
</Col>
|
||||||
|
<Col span={6}>
|
||||||
|
<Card hoverable onClick={() => navigate('/chatroom')}>
|
||||||
|
<Statistic
|
||||||
|
title="聊天室"
|
||||||
|
value={chatrooms.length}
|
||||||
|
prefix={<MessageOutlined />}
|
||||||
|
valueStyle={{ color: '#722ed1' }}
|
||||||
|
/>
|
||||||
|
</Card>
|
||||||
|
</Col>
|
||||||
|
<Col span={6}>
|
||||||
|
<Card>
|
||||||
|
<Statistic
|
||||||
|
title="活跃讨论"
|
||||||
|
value={activeRooms.length}
|
||||||
|
prefix={<CheckCircleOutlined />}
|
||||||
|
valueStyle={{ color: activeRooms.length > 0 ? '#fa8c16' : '#999' }}
|
||||||
|
/>
|
||||||
|
</Card>
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
|
|
||||||
|
<Row gutter={16}>
|
||||||
|
{/* 最近聊天室 */}
|
||||||
|
<Col span={12}>
|
||||||
|
<Card
|
||||||
|
title="聊天室"
|
||||||
|
extra={
|
||||||
|
<Button type="link" onClick={() => navigate('/chatroom')}>
|
||||||
|
查看全部 <ArrowRightOutlined />
|
||||||
|
</Button>
|
||||||
|
}
|
||||||
|
>
|
||||||
|
{chatrooms.length > 0 ? (
|
||||||
|
<List
|
||||||
|
dataSource={chatrooms.slice(0, 5)}
|
||||||
|
renderItem={(room) => (
|
||||||
|
<List.Item
|
||||||
|
actions={[
|
||||||
|
<Button
|
||||||
|
type="link"
|
||||||
|
size="small"
|
||||||
|
onClick={() => navigate(`/chatroom/${room.room_id}`)}
|
||||||
|
>
|
||||||
|
进入
|
||||||
|
</Button>
|
||||||
|
]}
|
||||||
|
>
|
||||||
|
<List.Item.Meta
|
||||||
|
title={
|
||||||
|
<span>
|
||||||
|
{room.name}
|
||||||
|
<Tag
|
||||||
|
color={
|
||||||
|
room.status === 'active' ? 'green' :
|
||||||
|
room.status === 'paused' ? 'orange' :
|
||||||
|
room.status === 'completed' ? 'blue' : 'default'
|
||||||
|
}
|
||||||
|
style={{ marginLeft: 8 }}
|
||||||
|
>
|
||||||
|
{room.status === 'active' ? '进行中' :
|
||||||
|
room.status === 'paused' ? '已暂停' :
|
||||||
|
room.status === 'completed' ? '已完成' : '空闲'}
|
||||||
|
</Tag>
|
||||||
|
</span>
|
||||||
|
}
|
||||||
|
description={
|
||||||
|
<Text type="secondary" ellipsis>
|
||||||
|
{room.objective || room.description || '暂无描述'}
|
||||||
|
</Text>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</List.Item>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<Empty description="暂无聊天室" />
|
||||||
|
)}
|
||||||
|
</Card>
|
||||||
|
</Col>
|
||||||
|
|
||||||
|
{/* Agent列表 */}
|
||||||
|
<Col span={12}>
|
||||||
|
<Card
|
||||||
|
title="Agent"
|
||||||
|
extra={
|
||||||
|
<Button type="link" onClick={() => navigate('/agents')}>
|
||||||
|
管理 <ArrowRightOutlined />
|
||||||
|
</Button>
|
||||||
|
}
|
||||||
|
>
|
||||||
|
{agents.length > 0 ? (
|
||||||
|
<List
|
||||||
|
dataSource={agents.slice(0, 5)}
|
||||||
|
renderItem={(agent) => (
|
||||||
|
<List.Item>
|
||||||
|
<List.Item.Meta
|
||||||
|
avatar={
|
||||||
|
<div style={{
|
||||||
|
width: 32,
|
||||||
|
height: 32,
|
||||||
|
borderRadius: '50%',
|
||||||
|
background: agent.color,
|
||||||
|
display: 'flex',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
color: '#fff'
|
||||||
|
}}>
|
||||||
|
<RobotOutlined />
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
title={
|
||||||
|
<span>
|
||||||
|
{agent.name}
|
||||||
|
{!agent.enabled && (
|
||||||
|
<Tag color="default" style={{ marginLeft: 8 }}>已禁用</Tag>
|
||||||
|
)}
|
||||||
|
</span>
|
||||||
|
}
|
||||||
|
description={agent.role}
|
||||||
|
/>
|
||||||
|
</List.Item>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<Empty description="暂无Agent" />
|
||||||
|
)}
|
||||||
|
</Card>
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default Dashboard
|
||||||
245
frontend/src/pages/DiscussionHistory.tsx
Normal file
245
frontend/src/pages/DiscussionHistory.tsx
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
/**
|
||||||
|
* 讨论历史页面
|
||||||
|
*/
|
||||||
|
import React, { useEffect, useState } from 'react'
|
||||||
|
import {
|
||||||
|
Table, Card, Tag, Typography, Space, Button,
|
||||||
|
Modal, List, Descriptions, Progress, Empty
|
||||||
|
} from 'antd'
|
||||||
|
import {
|
||||||
|
CheckCircleOutlined, CloseCircleOutlined,
|
||||||
|
EyeOutlined, RobotOutlined
|
||||||
|
} from '@ant-design/icons'
|
||||||
|
import { discussionApi } from '../services/api'
|
||||||
|
import type { DiscussionResult } from '../types'
|
||||||
|
import dayjs from 'dayjs'
|
||||||
|
|
||||||
|
const { Title, Text, Paragraph } = Typography
|
||||||
|
|
||||||
|
const DiscussionHistory: React.FC = () => {
|
||||||
|
const [discussions, setDiscussions] = useState<DiscussionResult[]>([])
|
||||||
|
const [total, setTotal] = useState(0)
|
||||||
|
const [loading, setLoading] = useState(false)
|
||||||
|
const [detailVisible, setDetailVisible] = useState(false)
|
||||||
|
const [selectedDiscussion, setSelectedDiscussion] = useState<DiscussionResult | null>(null)
|
||||||
|
|
||||||
|
const fetchDiscussions = async (page = 1, pageSize = 20) => {
|
||||||
|
setLoading(true)
|
||||||
|
try {
|
||||||
|
const result = await discussionApi.list(undefined, pageSize)
|
||||||
|
setDiscussions(result.discussions)
|
||||||
|
setTotal(result.total)
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Failed to fetch discussions:', e)
|
||||||
|
} finally {
|
||||||
|
setLoading(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
fetchDiscussions()
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
const showDetail = (record: DiscussionResult) => {
|
||||||
|
setSelectedDiscussion(record)
|
||||||
|
setDetailVisible(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
const columns = [
|
||||||
|
{
|
||||||
|
title: '讨论目标',
|
||||||
|
dataIndex: 'objective',
|
||||||
|
key: 'objective',
|
||||||
|
width: 300,
|
||||||
|
ellipsis: true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '状态',
|
||||||
|
dataIndex: 'consensus_reached',
|
||||||
|
key: 'consensus_reached',
|
||||||
|
width: 120,
|
||||||
|
render: (reached: boolean, record: DiscussionResult) => (
|
||||||
|
<Space>
|
||||||
|
{reached ? (
|
||||||
|
<Tag color="success" icon={<CheckCircleOutlined />}>达成共识</Tag>
|
||||||
|
) : (
|
||||||
|
<Tag color="default" icon={<CloseCircleOutlined />}>未达成</Tag>
|
||||||
|
)}
|
||||||
|
</Space>
|
||||||
|
)
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '置信度',
|
||||||
|
dataIndex: 'confidence',
|
||||||
|
key: 'confidence',
|
||||||
|
width: 120,
|
||||||
|
render: (confidence: number) => (
|
||||||
|
<Progress
|
||||||
|
percent={Math.round(confidence * 100)}
|
||||||
|
size="small"
|
||||||
|
strokeColor={confidence >= 0.8 ? '#52c41a' : confidence >= 0.5 ? '#faad14' : '#ff4d4f'}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '轮数',
|
||||||
|
dataIndex: 'total_rounds',
|
||||||
|
key: 'total_rounds',
|
||||||
|
width: 80
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '消息数',
|
||||||
|
dataIndex: 'total_messages',
|
||||||
|
key: 'total_messages',
|
||||||
|
width: 80
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '参与Agent',
|
||||||
|
dataIndex: 'participating_agents',
|
||||||
|
key: 'participating_agents',
|
||||||
|
width: 150,
|
||||||
|
render: (agents: string[]) => (
|
||||||
|
<Text type="secondary">{agents.length} 个</Text>
|
||||||
|
)
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '时间',
|
||||||
|
dataIndex: 'created_at',
|
||||||
|
key: 'created_at',
|
||||||
|
width: 180,
|
||||||
|
render: (time: string) => dayjs(time).format('YYYY-MM-DD HH:mm')
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '操作',
|
||||||
|
key: 'actions',
|
||||||
|
width: 80,
|
||||||
|
render: (_: unknown, record: DiscussionResult) => (
|
||||||
|
<Button
|
||||||
|
type="link"
|
||||||
|
icon={<EyeOutlined />}
|
||||||
|
onClick={() => showDetail(record)}
|
||||||
|
>
|
||||||
|
详情
|
||||||
|
</Button>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div style={{ marginLeft: 200 }}>
|
||||||
|
<Title level={2}>讨论历史</Title>
|
||||||
|
|
||||||
|
<Card>
|
||||||
|
<Table
|
||||||
|
dataSource={discussions}
|
||||||
|
columns={columns}
|
||||||
|
rowKey="discussion_id"
|
||||||
|
loading={loading}
|
||||||
|
pagination={{
|
||||||
|
total,
|
||||||
|
pageSize: 20,
|
||||||
|
showTotal: (t) => `共 ${t} 条记录`
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</Card>
|
||||||
|
|
||||||
|
{/* 详情弹窗 */}
|
||||||
|
<Modal
|
||||||
|
title="讨论详情"
|
||||||
|
open={detailVisible}
|
||||||
|
onCancel={() => setDetailVisible(false)}
|
||||||
|
footer={null}
|
||||||
|
width={800}
|
||||||
|
>
|
||||||
|
{selectedDiscussion && (
|
||||||
|
<div>
|
||||||
|
<Descriptions bordered column={2}>
|
||||||
|
<Descriptions.Item label="讨论目标" span={2}>
|
||||||
|
{selectedDiscussion.objective}
|
||||||
|
</Descriptions.Item>
|
||||||
|
<Descriptions.Item label="状态">
|
||||||
|
{selectedDiscussion.consensus_reached ? (
|
||||||
|
<Tag color="success">达成共识</Tag>
|
||||||
|
) : (
|
||||||
|
<Tag color="default">未达成共识</Tag>
|
||||||
|
)}
|
||||||
|
</Descriptions.Item>
|
||||||
|
<Descriptions.Item label="置信度">
|
||||||
|
<Progress
|
||||||
|
percent={Math.round(selectedDiscussion.confidence * 100)}
|
||||||
|
size="small"
|
||||||
|
style={{ width: 150 }}
|
||||||
|
/>
|
||||||
|
</Descriptions.Item>
|
||||||
|
<Descriptions.Item label="总轮数">
|
||||||
|
{selectedDiscussion.total_rounds}
|
||||||
|
</Descriptions.Item>
|
||||||
|
<Descriptions.Item label="总消息数">
|
||||||
|
{selectedDiscussion.total_messages}
|
||||||
|
</Descriptions.Item>
|
||||||
|
<Descriptions.Item label="结束原因" span={2}>
|
||||||
|
{selectedDiscussion.end_reason || '无'}
|
||||||
|
</Descriptions.Item>
|
||||||
|
<Descriptions.Item label="开始时间">
|
||||||
|
{dayjs(selectedDiscussion.created_at).format('YYYY-MM-DD HH:mm:ss')}
|
||||||
|
</Descriptions.Item>
|
||||||
|
<Descriptions.Item label="结束时间">
|
||||||
|
{selectedDiscussion.completed_at
|
||||||
|
? dayjs(selectedDiscussion.completed_at).format('YYYY-MM-DD HH:mm:ss')
|
||||||
|
: '进行中'}
|
||||||
|
</Descriptions.Item>
|
||||||
|
</Descriptions>
|
||||||
|
|
||||||
|
<Card title="结果摘要" style={{ marginTop: 16 }}>
|
||||||
|
<Paragraph>{selectedDiscussion.summary || '暂无摘要'}</Paragraph>
|
||||||
|
</Card>
|
||||||
|
|
||||||
|
<Card title="行动项" style={{ marginTop: 16 }}>
|
||||||
|
{selectedDiscussion.action_items.length > 0 ? (
|
||||||
|
<List
|
||||||
|
dataSource={selectedDiscussion.action_items}
|
||||||
|
renderItem={(item) => (
|
||||||
|
<List.Item>
|
||||||
|
<CheckCircleOutlined style={{ color: '#52c41a', marginRight: 8 }} />
|
||||||
|
{item}
|
||||||
|
</List.Item>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<Empty description="暂无行动项" />
|
||||||
|
)}
|
||||||
|
</Card>
|
||||||
|
|
||||||
|
<Card title="未解决问题" style={{ marginTop: 16 }}>
|
||||||
|
{selectedDiscussion.unresolved_issues.length > 0 ? (
|
||||||
|
<List
|
||||||
|
dataSource={selectedDiscussion.unresolved_issues}
|
||||||
|
renderItem={(item) => (
|
||||||
|
<List.Item>
|
||||||
|
<CloseCircleOutlined style={{ color: '#faad14', marginRight: 8 }} />
|
||||||
|
{item}
|
||||||
|
</List.Item>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<Empty description="无未解决问题" />
|
||||||
|
)}
|
||||||
|
</Card>
|
||||||
|
|
||||||
|
<Card title="Agent贡献" style={{ marginTop: 16 }}>
|
||||||
|
<Space wrap>
|
||||||
|
{Object.entries(selectedDiscussion.agent_contributions).map(([agentId, count]) => (
|
||||||
|
<Tag key={agentId} icon={<RobotOutlined />}>
|
||||||
|
{agentId}: {count}条消息
|
||||||
|
</Tag>
|
||||||
|
))}
|
||||||
|
</Space>
|
||||||
|
</Card>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</Modal>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default DiscussionHistory
|
||||||
260
frontend/src/pages/ProviderConfig.tsx
Normal file
260
frontend/src/pages/ProviderConfig.tsx
Normal file
@@ -0,0 +1,260 @@
|
|||||||
|
/**
|
||||||
|
* AI接口配置页面
|
||||||
|
*/
|
||||||
|
import React, { useEffect, useState } from 'react'
|
||||||
|
import {
|
||||||
|
Row, Col, Button, Modal, Form, Input, Select, Switch,
|
||||||
|
InputNumber, Typography, message, Spin, Empty
|
||||||
|
} from 'antd'
|
||||||
|
import { PlusOutlined } from '@ant-design/icons'
|
||||||
|
import { useProviderStore } from '../stores/providerStore'
|
||||||
|
import ProviderCard from '../components/ProviderCard'
|
||||||
|
import type { AIProvider } from '../types'
|
||||||
|
|
||||||
|
const { Title } = Typography
|
||||||
|
|
||||||
|
// 提供商类型选项
|
||||||
|
const PROVIDER_TYPES = [
|
||||||
|
{ value: 'minimax', label: 'MiniMax' },
|
||||||
|
{ value: 'zhipu', label: '智谱AI (ChatGLM)' },
|
||||||
|
{ value: 'openrouter', label: 'OpenRouter' },
|
||||||
|
{ value: 'kimi', label: 'Kimi (月之暗面)' },
|
||||||
|
{ value: 'deepseek', label: 'DeepSeek' },
|
||||||
|
{ value: 'gemini', label: 'Google Gemini' },
|
||||||
|
{ value: 'ollama', label: 'Ollama (本地)' },
|
||||||
|
{ value: 'llmstudio', label: 'LLM Studio (本地)' }
|
||||||
|
]
|
||||||
|
|
||||||
|
const ProviderConfig: React.FC = () => {
|
||||||
|
const {
|
||||||
|
providers,
|
||||||
|
loading,
|
||||||
|
fetchProviders,
|
||||||
|
createProvider,
|
||||||
|
updateProvider,
|
||||||
|
deleteProvider,
|
||||||
|
testProvider
|
||||||
|
} = useProviderStore()
|
||||||
|
|
||||||
|
const [modalVisible, setModalVisible] = useState(false)
|
||||||
|
const [editingProvider, setEditingProvider] = useState<AIProvider | null>(null)
|
||||||
|
const [testingId, setTestingId] = useState<string | null>(null)
|
||||||
|
const [form] = Form.useForm()
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
fetchProviders()
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
// 打开创建/编辑弹窗
|
||||||
|
const openModal = (provider?: AIProvider) => {
|
||||||
|
setEditingProvider(provider || null)
|
||||||
|
if (provider) {
|
||||||
|
form.setFieldsValue({
|
||||||
|
provider_type: provider.provider_type,
|
||||||
|
name: provider.name,
|
||||||
|
model: provider.model,
|
||||||
|
api_key: '', // 不显示已有密钥
|
||||||
|
base_url: provider.base_url,
|
||||||
|
use_proxy: provider.use_proxy,
|
||||||
|
http_proxy: provider.proxy_config?.http_proxy,
|
||||||
|
https_proxy: provider.proxy_config?.https_proxy,
|
||||||
|
timeout: provider.timeout
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
form.resetFields()
|
||||||
|
}
|
||||||
|
setModalVisible(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保存
|
||||||
|
const handleSave = async () => {
|
||||||
|
try {
|
||||||
|
const values = await form.validateFields()
|
||||||
|
|
||||||
|
const data = {
|
||||||
|
provider_type: values.provider_type,
|
||||||
|
name: values.name,
|
||||||
|
model: values.model,
|
||||||
|
api_key: values.api_key || undefined,
|
||||||
|
base_url: values.base_url || '',
|
||||||
|
use_proxy: values.use_proxy || false,
|
||||||
|
proxy_config: values.use_proxy ? {
|
||||||
|
http_proxy: values.http_proxy,
|
||||||
|
https_proxy: values.https_proxy
|
||||||
|
} : undefined,
|
||||||
|
timeout: values.timeout || 60
|
||||||
|
}
|
||||||
|
|
||||||
|
if (editingProvider) {
|
||||||
|
await updateProvider(editingProvider.provider_id, data)
|
||||||
|
message.success('更新成功')
|
||||||
|
} else {
|
||||||
|
await createProvider(data)
|
||||||
|
message.success('创建成功')
|
||||||
|
}
|
||||||
|
|
||||||
|
setModalVisible(false)
|
||||||
|
} catch (e) {
|
||||||
|
message.error((e as Error).message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除
|
||||||
|
const handleDelete = (provider: AIProvider) => {
|
||||||
|
Modal.confirm({
|
||||||
|
title: '确认删除',
|
||||||
|
content: `确定要删除AI接口 "${provider.name}" 吗?`,
|
||||||
|
okText: '删除',
|
||||||
|
okType: 'danger',
|
||||||
|
onOk: async () => {
|
||||||
|
try {
|
||||||
|
await deleteProvider(provider.provider_id)
|
||||||
|
message.success('删除成功')
|
||||||
|
} catch (e) {
|
||||||
|
message.error((e as Error).message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试
|
||||||
|
const handleTest = async (provider: AIProvider) => {
|
||||||
|
setTestingId(provider.provider_id)
|
||||||
|
try {
|
||||||
|
const result = await testProvider(provider.provider_id)
|
||||||
|
if (result.success) {
|
||||||
|
message.success(`连接成功,延迟: ${result.latency_ms?.toFixed(0)}ms`)
|
||||||
|
} else {
|
||||||
|
message.error(`连接失败: ${result.message}`)
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
message.error((e as Error).message)
|
||||||
|
} finally {
|
||||||
|
setTestingId(null)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 切换启用状态
|
||||||
|
const handleToggle = async (provider: AIProvider, enabled: boolean) => {
|
||||||
|
try {
|
||||||
|
await updateProvider(provider.provider_id, { enabled })
|
||||||
|
message.success(enabled ? '已启用' : '已禁用')
|
||||||
|
} catch (e) {
|
||||||
|
message.error((e as Error).message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div style={{ marginLeft: 200 }}>
|
||||||
|
<div style={{ display: 'flex', justifyContent: 'space-between', marginBottom: 24 }}>
|
||||||
|
<Title level={2}>AI接口配置</Title>
|
||||||
|
<Button type="primary" icon={<PlusOutlined />} onClick={() => openModal()}>
|
||||||
|
添加接口
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Spin spinning={loading}>
|
||||||
|
{providers.length > 0 ? (
|
||||||
|
<Row gutter={[16, 16]}>
|
||||||
|
{providers.map((provider) => (
|
||||||
|
<Col key={provider.provider_id} xs={24} sm={12} lg={8} xl={6}>
|
||||||
|
<ProviderCard
|
||||||
|
provider={provider}
|
||||||
|
onEdit={openModal}
|
||||||
|
onDelete={handleDelete}
|
||||||
|
onTest={handleTest}
|
||||||
|
onToggleEnabled={handleToggle}
|
||||||
|
testing={testingId === provider.provider_id}
|
||||||
|
/>
|
||||||
|
</Col>
|
||||||
|
))}
|
||||||
|
</Row>
|
||||||
|
) : (
|
||||||
|
<Empty description="暂无AI接口配置" />
|
||||||
|
)}
|
||||||
|
</Spin>
|
||||||
|
|
||||||
|
{/* 创建/编辑弹窗 */}
|
||||||
|
<Modal
|
||||||
|
title={editingProvider ? '编辑AI接口' : '添加AI接口'}
|
||||||
|
open={modalVisible}
|
||||||
|
onOk={handleSave}
|
||||||
|
onCancel={() => setModalVisible(false)}
|
||||||
|
width={600}
|
||||||
|
>
|
||||||
|
<Form form={form} layout="vertical">
|
||||||
|
<Form.Item
|
||||||
|
name="provider_type"
|
||||||
|
label="提供商类型"
|
||||||
|
rules={[{ required: true, message: '请选择提供商类型' }]}
|
||||||
|
>
|
||||||
|
<Select
|
||||||
|
options={PROVIDER_TYPES}
|
||||||
|
placeholder="选择AI提供商"
|
||||||
|
disabled={!!editingProvider}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item
|
||||||
|
name="name"
|
||||||
|
label="名称"
|
||||||
|
rules={[{ required: true, message: '请输入名称' }]}
|
||||||
|
>
|
||||||
|
<Input placeholder="自定义名称" />
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item
|
||||||
|
name="model"
|
||||||
|
label="模型"
|
||||||
|
rules={[{ required: true, message: '请输入模型名称' }]}
|
||||||
|
>
|
||||||
|
<Input placeholder="如: gpt-4-turbo, glm-4" />
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item
|
||||||
|
name="api_key"
|
||||||
|
label="API密钥"
|
||||||
|
rules={[{ required: !editingProvider, message: '请输入API密钥' }]}
|
||||||
|
>
|
||||||
|
<Input.Password
|
||||||
|
placeholder={editingProvider ? '留空则不修改' : '输入API密钥'}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item name="base_url" label="API地址">
|
||||||
|
<Input placeholder="留空使用默认地址" />
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item name="use_proxy" label="使用代理" valuePropName="checked">
|
||||||
|
<Switch />
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item noStyle shouldUpdate={(prev, curr) => prev.use_proxy !== curr.use_proxy}>
|
||||||
|
{({ getFieldValue }) =>
|
||||||
|
getFieldValue('use_proxy') && (
|
||||||
|
<Row gutter={16}>
|
||||||
|
<Col span={12}>
|
||||||
|
<Form.Item name="http_proxy" label="HTTP代理">
|
||||||
|
<Input placeholder="http://127.0.0.1:7890" />
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
<Col span={12}>
|
||||||
|
<Form.Item name="https_proxy" label="HTTPS代理">
|
||||||
|
<Input placeholder="http://127.0.0.1:7890" />
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item name="timeout" label="超时时间(秒)">
|
||||||
|
<InputNumber min={10} max={300} style={{ width: '100%' }} />
|
||||||
|
</Form.Item>
|
||||||
|
</Form>
|
||||||
|
</Modal>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default ProviderConfig
|
||||||
195
frontend/src/services/api.ts
Normal file
195
frontend/src/services/api.ts
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
/**
|
||||||
|
* API服务封装
|
||||||
|
*/
|
||||||
|
import axios from 'axios'
|
||||||
|
import type {
|
||||||
|
AIProvider,
|
||||||
|
Agent,
|
||||||
|
AgentTemplate,
|
||||||
|
ChatRoom,
|
||||||
|
Message,
|
||||||
|
DiscussionResult,
|
||||||
|
TestResponse
|
||||||
|
} from '../types'
|
||||||
|
|
||||||
|
const api = axios.create({
|
||||||
|
baseURL: '/api',
|
||||||
|
timeout: 30000,
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// 响应拦截器
|
||||||
|
api.interceptors.response.use(
|
||||||
|
response => response.data,
|
||||||
|
error => {
|
||||||
|
const message = error.response?.data?.detail || error.message || '请求失败'
|
||||||
|
return Promise.reject(new Error(message))
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// ============ AI接口管理 ============
|
||||||
|
|
||||||
|
export const providerApi = {
|
||||||
|
// 获取所有接口
|
||||||
|
list: (enabledOnly = false): Promise<AIProvider[]> =>
|
||||||
|
api.get('/providers', { params: { enabled_only: enabledOnly } }),
|
||||||
|
|
||||||
|
// 获取单个接口
|
||||||
|
get: (id: string): Promise<AIProvider> =>
|
||||||
|
api.get(`/providers/${id}`),
|
||||||
|
|
||||||
|
// 创建接口
|
||||||
|
create: (data: Partial<AIProvider>): Promise<AIProvider> =>
|
||||||
|
api.post('/providers', data),
|
||||||
|
|
||||||
|
// 更新接口
|
||||||
|
update: (id: string, data: Partial<AIProvider>): Promise<AIProvider> =>
|
||||||
|
api.put(`/providers/${id}`, data),
|
||||||
|
|
||||||
|
// 删除接口
|
||||||
|
delete: (id: string): Promise<void> =>
|
||||||
|
api.delete(`/providers/${id}`),
|
||||||
|
|
||||||
|
// 测试接口
|
||||||
|
test: (id: string): Promise<TestResponse> =>
|
||||||
|
api.post(`/providers/${id}/test`),
|
||||||
|
|
||||||
|
// 测试配置
|
||||||
|
testConfig: (data: Record<string, unknown>): Promise<TestResponse> =>
|
||||||
|
api.post('/providers/test', data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ Agent管理 ============
|
||||||
|
|
||||||
|
export const agentApi = {
|
||||||
|
// 获取所有Agent
|
||||||
|
list: (enabledOnly = false): Promise<Agent[]> =>
|
||||||
|
api.get('/agents', { params: { enabled_only: enabledOnly } }),
|
||||||
|
|
||||||
|
// 获取单个Agent
|
||||||
|
get: (id: string): Promise<Agent> =>
|
||||||
|
api.get(`/agents/${id}`),
|
||||||
|
|
||||||
|
// 创建Agent
|
||||||
|
create: (data: Partial<Agent>): Promise<Agent> =>
|
||||||
|
api.post('/agents', data),
|
||||||
|
|
||||||
|
// 更新Agent
|
||||||
|
update: (id: string, data: Partial<Agent>): Promise<Agent> =>
|
||||||
|
api.put(`/agents/${id}`, data),
|
||||||
|
|
||||||
|
// 删除Agent
|
||||||
|
delete: (id: string): Promise<void> =>
|
||||||
|
api.delete(`/agents/${id}`),
|
||||||
|
|
||||||
|
// 测试Agent
|
||||||
|
test: (id: string, message?: string): Promise<TestResponse & { response?: string }> =>
|
||||||
|
api.post(`/agents/${id}/test`, { message }),
|
||||||
|
|
||||||
|
// 复制Agent
|
||||||
|
duplicate: (id: string, newName?: string): Promise<Agent> =>
|
||||||
|
api.post(`/agents/${id}/duplicate`, null, { params: { new_name: newName } }),
|
||||||
|
|
||||||
|
// 获取模板
|
||||||
|
getTemplates: (): Promise<AgentTemplate[]> =>
|
||||||
|
api.get('/agents/templates'),
|
||||||
|
|
||||||
|
// 从模板创建
|
||||||
|
createFromTemplate: (templateId: string, providerId: string): Promise<Agent> =>
|
||||||
|
api.post(`/agents/from-template/${templateId}`, null, { params: { provider_id: providerId } }),
|
||||||
|
|
||||||
|
// AI生成系统提示词
|
||||||
|
generatePrompt: (data: {
|
||||||
|
provider_id: string
|
||||||
|
name: string
|
||||||
|
role: string
|
||||||
|
description?: string
|
||||||
|
}): Promise<{ success: boolean; prompt?: string; message?: string }> =>
|
||||||
|
api.post('/agents/generate-prompt', data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ 聊天室管理 ============
|
||||||
|
|
||||||
|
export const chatroomApi = {
|
||||||
|
// 获取所有聊天室
|
||||||
|
list: (): Promise<ChatRoom[]> =>
|
||||||
|
api.get('/chatrooms'),
|
||||||
|
|
||||||
|
// 获取单个聊天室
|
||||||
|
get: (id: string): Promise<ChatRoom> =>
|
||||||
|
api.get(`/chatrooms/${id}`),
|
||||||
|
|
||||||
|
// 创建聊天室
|
||||||
|
create: (data: Partial<ChatRoom>): Promise<ChatRoom> =>
|
||||||
|
api.post('/chatrooms', data),
|
||||||
|
|
||||||
|
// 更新聊天室
|
||||||
|
update: (id: string, data: Partial<ChatRoom>): Promise<ChatRoom> =>
|
||||||
|
api.put(`/chatrooms/${id}`, data),
|
||||||
|
|
||||||
|
// 删除聊天室
|
||||||
|
delete: (id: string): Promise<void> =>
|
||||||
|
api.delete(`/chatrooms/${id}`),
|
||||||
|
|
||||||
|
// 添加Agent
|
||||||
|
addAgent: (roomId: string, agentId: string): Promise<ChatRoom> =>
|
||||||
|
api.post(`/chatrooms/${roomId}/agents/${agentId}`),
|
||||||
|
|
||||||
|
// 移除Agent
|
||||||
|
removeAgent: (roomId: string, agentId: string): Promise<ChatRoom> =>
|
||||||
|
api.delete(`/chatrooms/${roomId}/agents/${agentId}`),
|
||||||
|
|
||||||
|
// 获取消息
|
||||||
|
getMessages: (roomId: string, limit = 50, discussionId?: string): Promise<Message[]> =>
|
||||||
|
api.get(`/chatrooms/${roomId}/messages`, {
|
||||||
|
params: { limit, discussion_id: discussionId }
|
||||||
|
}),
|
||||||
|
|
||||||
|
// 启动讨论
|
||||||
|
startDiscussion: (roomId: string, objective: string): Promise<{ is_active: boolean }> =>
|
||||||
|
api.post(`/chatrooms/${roomId}/start`, { objective }),
|
||||||
|
|
||||||
|
// 暂停讨论
|
||||||
|
pauseDiscussion: (roomId: string): Promise<{ is_active: boolean }> =>
|
||||||
|
api.post(`/chatrooms/${roomId}/pause`),
|
||||||
|
|
||||||
|
// 恢复讨论
|
||||||
|
resumeDiscussion: (roomId: string): Promise<{ is_active: boolean }> =>
|
||||||
|
api.post(`/chatrooms/${roomId}/resume`),
|
||||||
|
|
||||||
|
// 停止讨论
|
||||||
|
stopDiscussion: (roomId: string): Promise<{ is_active: boolean }> =>
|
||||||
|
api.post(`/chatrooms/${roomId}/stop`),
|
||||||
|
|
||||||
|
// 获取状态
|
||||||
|
getStatus: (roomId: string): Promise<{
|
||||||
|
is_active: boolean
|
||||||
|
status: string
|
||||||
|
current_round: number
|
||||||
|
}> =>
|
||||||
|
api.get(`/chatrooms/${roomId}/status`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ 讨论结果 ============
|
||||||
|
|
||||||
|
export const discussionApi = {
|
||||||
|
// 获取列表
|
||||||
|
list: (roomId?: string, limit = 20): Promise<{ discussions: DiscussionResult[], total: number }> =>
|
||||||
|
api.get('/discussions', { params: { room_id: roomId, limit } }),
|
||||||
|
|
||||||
|
// 获取单个
|
||||||
|
get: (id: string): Promise<DiscussionResult> =>
|
||||||
|
api.get(`/discussions/${id}`),
|
||||||
|
|
||||||
|
// 获取聊天室历史
|
||||||
|
getRoomHistory: (roomId: string, limit = 10): Promise<DiscussionResult[]> =>
|
||||||
|
api.get(`/discussions/room/${roomId}`, { params: { limit } }),
|
||||||
|
|
||||||
|
// 获取最新
|
||||||
|
getLatest: (roomId: string): Promise<DiscussionResult> =>
|
||||||
|
api.get(`/discussions/room/${roomId}/latest`)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default api
|
||||||
173
frontend/src/services/websocket.ts
Normal file
173
frontend/src/services/websocket.ts
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
/**
|
||||||
|
* WebSocket服务
|
||||||
|
* 管理聊天室的实时连接
|
||||||
|
*/
|
||||||
|
import type { WSMessage, Message } from '../types'
|
||||||
|
|
||||||
|
type MessageHandler = (message: WSMessage) => void
|
||||||
|
type ErrorHandler = (error: Event) => void
|
||||||
|
|
||||||
|
class WebSocketService {
|
||||||
|
private ws: WebSocket | null = null
|
||||||
|
private roomId: string | null = null
|
||||||
|
private messageHandlers: MessageHandler[] = []
|
||||||
|
private errorHandlers: ErrorHandler[] = []
|
||||||
|
private reconnectAttempts = 0
|
||||||
|
private maxReconnectAttempts = 5
|
||||||
|
private reconnectDelay = 1000
|
||||||
|
private heartbeatInterval: number | null = null
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 连接到聊天室
|
||||||
|
*/
|
||||||
|
connect(roomId: string): Promise<void> {
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
if (this.ws && this.roomId === roomId) {
|
||||||
|
resolve()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 关闭现有连接
|
||||||
|
this.disconnect()
|
||||||
|
|
||||||
|
this.roomId = roomId
|
||||||
|
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'
|
||||||
|
const wsUrl = `${protocol}//${window.location.host}/api/chatrooms/ws/${roomId}`
|
||||||
|
|
||||||
|
try {
|
||||||
|
this.ws = new WebSocket(wsUrl)
|
||||||
|
|
||||||
|
this.ws.onopen = () => {
|
||||||
|
console.log('WebSocket connected:', roomId)
|
||||||
|
this.reconnectAttempts = 0
|
||||||
|
this.startHeartbeat()
|
||||||
|
resolve()
|
||||||
|
}
|
||||||
|
|
||||||
|
this.ws.onmessage = (event) => {
|
||||||
|
try {
|
||||||
|
const data = JSON.parse(event.data) as WSMessage
|
||||||
|
this.messageHandlers.forEach(handler => handler(data))
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Failed to parse WebSocket message:', e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
this.ws.onerror = (error) => {
|
||||||
|
console.error('WebSocket error:', error)
|
||||||
|
this.errorHandlers.forEach(handler => handler(error))
|
||||||
|
reject(error)
|
||||||
|
}
|
||||||
|
|
||||||
|
this.ws.onclose = () => {
|
||||||
|
console.log('WebSocket disconnected')
|
||||||
|
this.stopHeartbeat()
|
||||||
|
this.attemptReconnect()
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
reject(e)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 断开连接
|
||||||
|
*/
|
||||||
|
disconnect(): void {
|
||||||
|
this.stopHeartbeat()
|
||||||
|
if (this.ws) {
|
||||||
|
this.ws.close()
|
||||||
|
this.ws = null
|
||||||
|
}
|
||||||
|
this.roomId = null
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 发送消息
|
||||||
|
*/
|
||||||
|
send(data: unknown): void {
|
||||||
|
if (this.ws && this.ws.readyState === WebSocket.OPEN) {
|
||||||
|
this.ws.send(typeof data === 'string' ? data : JSON.stringify(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 添加消息处理器
|
||||||
|
*/
|
||||||
|
onMessage(handler: MessageHandler): () => void {
|
||||||
|
this.messageHandlers.push(handler)
|
||||||
|
return () => {
|
||||||
|
this.messageHandlers = this.messageHandlers.filter(h => h !== handler)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 添加错误处理器
|
||||||
|
*/
|
||||||
|
onError(handler: ErrorHandler): () => void {
|
||||||
|
this.errorHandlers.push(handler)
|
||||||
|
return () => {
|
||||||
|
this.errorHandlers = this.errorHandlers.filter(h => h !== handler)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检查是否已连接
|
||||||
|
*/
|
||||||
|
isConnected(): boolean {
|
||||||
|
return this.ws !== null && this.ws.readyState === WebSocket.OPEN
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取当前房间ID
|
||||||
|
*/
|
||||||
|
getCurrentRoomId(): string | null {
|
||||||
|
return this.roomId
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 开始心跳
|
||||||
|
*/
|
||||||
|
private startHeartbeat(): void {
|
||||||
|
this.heartbeatInterval = window.setInterval(() => {
|
||||||
|
this.send('ping')
|
||||||
|
}, 30000)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 停止心跳
|
||||||
|
*/
|
||||||
|
private stopHeartbeat(): void {
|
||||||
|
if (this.heartbeatInterval) {
|
||||||
|
clearInterval(this.heartbeatInterval)
|
||||||
|
this.heartbeatInterval = null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 尝试重连
|
||||||
|
*/
|
||||||
|
private attemptReconnect(): void {
|
||||||
|
if (this.reconnectAttempts >= this.maxReconnectAttempts) {
|
||||||
|
console.log('Max reconnect attempts reached')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!this.roomId) return
|
||||||
|
|
||||||
|
this.reconnectAttempts++
|
||||||
|
const delay = this.reconnectDelay * Math.pow(2, this.reconnectAttempts - 1)
|
||||||
|
|
||||||
|
console.log(`Attempting reconnect in ${delay}ms...`)
|
||||||
|
|
||||||
|
setTimeout(() => {
|
||||||
|
if (this.roomId) {
|
||||||
|
this.connect(this.roomId).catch(console.error)
|
||||||
|
}
|
||||||
|
}, delay)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 导出单例
|
||||||
|
export const wsService = new WebSocketService()
|
||||||
|
export default wsService
|
||||||
103
frontend/src/stores/agentStore.ts
Normal file
103
frontend/src/stores/agentStore.ts
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
/**
|
||||||
|
* Agent状态管理
|
||||||
|
*/
|
||||||
|
import { create } from 'zustand'
|
||||||
|
import type { Agent, AgentTemplate } from '../types'
|
||||||
|
import { agentApi } from '../services/api'
|
||||||
|
|
||||||
|
interface AgentState {
|
||||||
|
agents: Agent[]
|
||||||
|
templates: AgentTemplate[]
|
||||||
|
loading: boolean
|
||||||
|
error: string | null
|
||||||
|
|
||||||
|
// Actions
|
||||||
|
fetchAgents: () => Promise<void>
|
||||||
|
fetchTemplates: () => Promise<void>
|
||||||
|
createAgent: (data: Partial<Agent>) => Promise<Agent>
|
||||||
|
updateAgent: (id: string, data: Partial<Agent>) => Promise<Agent>
|
||||||
|
deleteAgent: (id: string) => Promise<void>
|
||||||
|
testAgent: (id: string, message?: string) => Promise<{ success: boolean; response?: string }>
|
||||||
|
duplicateAgent: (id: string, newName?: string) => Promise<Agent>
|
||||||
|
createFromTemplate: (templateId: string, providerId: string) => Promise<Agent>
|
||||||
|
generatePrompt: (data: {
|
||||||
|
provider_id: string
|
||||||
|
name: string
|
||||||
|
role: string
|
||||||
|
description?: string
|
||||||
|
}) => Promise<{ success: boolean; prompt?: string; message?: string }>
|
||||||
|
getAgentById: (id: string) => Agent | undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useAgentStore = create<AgentState>((set, get) => ({
|
||||||
|
agents: [],
|
||||||
|
templates: [],
|
||||||
|
loading: false,
|
||||||
|
error: null,
|
||||||
|
|
||||||
|
fetchAgents: async () => {
|
||||||
|
set({ loading: true, error: null })
|
||||||
|
try {
|
||||||
|
const agents = await agentApi.list()
|
||||||
|
set({ agents, loading: false })
|
||||||
|
} catch (e) {
|
||||||
|
set({ error: (e as Error).message, loading: false })
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
fetchTemplates: async () => {
|
||||||
|
try {
|
||||||
|
const templates = await agentApi.getTemplates()
|
||||||
|
set({ templates })
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Failed to fetch templates:', e)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
createAgent: async (data) => {
|
||||||
|
const agent = await agentApi.create(data)
|
||||||
|
set({ agents: [...get().agents, agent] })
|
||||||
|
return agent
|
||||||
|
},
|
||||||
|
|
||||||
|
updateAgent: async (id, data) => {
|
||||||
|
const agent = await agentApi.update(id, data)
|
||||||
|
set({
|
||||||
|
agents: get().agents.map(a =>
|
||||||
|
a.agent_id === id ? agent : a
|
||||||
|
)
|
||||||
|
})
|
||||||
|
return agent
|
||||||
|
},
|
||||||
|
|
||||||
|
deleteAgent: async (id) => {
|
||||||
|
await agentApi.delete(id)
|
||||||
|
set({
|
||||||
|
agents: get().agents.filter(a => a.agent_id !== id)
|
||||||
|
})
|
||||||
|
},
|
||||||
|
|
||||||
|
testAgent: async (id, message) => {
|
||||||
|
return await agentApi.test(id, message)
|
||||||
|
},
|
||||||
|
|
||||||
|
duplicateAgent: async (id, newName) => {
|
||||||
|
const agent = await agentApi.duplicate(id, newName)
|
||||||
|
set({ agents: [...get().agents, agent] })
|
||||||
|
return agent
|
||||||
|
},
|
||||||
|
|
||||||
|
createFromTemplate: async (templateId, providerId) => {
|
||||||
|
const agent = await agentApi.createFromTemplate(templateId, providerId)
|
||||||
|
set({ agents: [...get().agents, agent] })
|
||||||
|
return agent
|
||||||
|
},
|
||||||
|
|
||||||
|
generatePrompt: async (data) => {
|
||||||
|
return await agentApi.generatePrompt(data)
|
||||||
|
},
|
||||||
|
|
||||||
|
getAgentById: (id) => {
|
||||||
|
return get().agents.find(a => a.agent_id === id)
|
||||||
|
}
|
||||||
|
}))
|
||||||
183
frontend/src/stores/chatroomStore.ts
Normal file
183
frontend/src/stores/chatroomStore.ts
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
/**
|
||||||
|
* 聊天室状态管理
|
||||||
|
*/
|
||||||
|
import { create } from 'zustand'
|
||||||
|
import type { ChatRoom, Message, WSMessage } from '../types'
|
||||||
|
import { chatroomApi } from '../services/api'
|
||||||
|
import { wsService } from '../services/websocket'
|
||||||
|
|
||||||
|
interface ChatroomState {
|
||||||
|
chatrooms: ChatRoom[]
|
||||||
|
currentRoom: ChatRoom | null
|
||||||
|
messages: Message[]
|
||||||
|
typingAgents: Set<string>
|
||||||
|
loading: boolean
|
||||||
|
error: string | null
|
||||||
|
|
||||||
|
// Actions
|
||||||
|
fetchChatrooms: () => Promise<void>
|
||||||
|
fetchChatroom: (id: string) => Promise<void>
|
||||||
|
createChatroom: (data: Partial<ChatRoom>) => Promise<ChatRoom>
|
||||||
|
updateChatroom: (id: string, data: Partial<ChatRoom>) => Promise<ChatRoom>
|
||||||
|
deleteChatroom: (id: string) => Promise<void>
|
||||||
|
|
||||||
|
// 讨论控制
|
||||||
|
startDiscussion: (roomId: string, objective: string) => Promise<void>
|
||||||
|
pauseDiscussion: (roomId: string) => Promise<void>
|
||||||
|
resumeDiscussion: (roomId: string) => Promise<void>
|
||||||
|
stopDiscussion: (roomId: string) => Promise<void>
|
||||||
|
|
||||||
|
// WebSocket
|
||||||
|
connectWebSocket: (roomId: string) => Promise<void>
|
||||||
|
disconnectWebSocket: () => void
|
||||||
|
|
||||||
|
// 消息
|
||||||
|
fetchMessages: (roomId: string) => Promise<void>
|
||||||
|
addMessage: (message: Message) => void
|
||||||
|
setTyping: (agentId: string, isTyping: boolean) => void
|
||||||
|
|
||||||
|
// 状态更新
|
||||||
|
updateRoomStatus: (status: string, round?: number) => void
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useChatroomStore = create<ChatroomState>((set, get) => ({
|
||||||
|
chatrooms: [],
|
||||||
|
currentRoom: null,
|
||||||
|
messages: [],
|
||||||
|
typingAgents: new Set(),
|
||||||
|
loading: false,
|
||||||
|
error: null,
|
||||||
|
|
||||||
|
fetchChatrooms: async () => {
|
||||||
|
set({ loading: true, error: null })
|
||||||
|
try {
|
||||||
|
const chatrooms = await chatroomApi.list()
|
||||||
|
set({ chatrooms, loading: false })
|
||||||
|
} catch (e) {
|
||||||
|
set({ error: (e as Error).message, loading: false })
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
fetchChatroom: async (id) => {
|
||||||
|
set({ loading: true, error: null })
|
||||||
|
try {
|
||||||
|
const room = await chatroomApi.get(id)
|
||||||
|
set({ currentRoom: room, loading: false })
|
||||||
|
} catch (e) {
|
||||||
|
set({ error: (e as Error).message, loading: false })
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
createChatroom: async (data) => {
|
||||||
|
const room = await chatroomApi.create(data)
|
||||||
|
set({ chatrooms: [...get().chatrooms, room] })
|
||||||
|
return room
|
||||||
|
},
|
||||||
|
|
||||||
|
updateChatroom: async (id, data) => {
|
||||||
|
const room = await chatroomApi.update(id, data)
|
||||||
|
set({
|
||||||
|
chatrooms: get().chatrooms.map(r =>
|
||||||
|
r.room_id === id ? room : r
|
||||||
|
),
|
||||||
|
currentRoom: get().currentRoom?.room_id === id ? room : get().currentRoom
|
||||||
|
})
|
||||||
|
return room
|
||||||
|
},
|
||||||
|
|
||||||
|
deleteChatroom: async (id) => {
|
||||||
|
await chatroomApi.delete(id)
|
||||||
|
set({
|
||||||
|
chatrooms: get().chatrooms.filter(r => r.room_id !== id),
|
||||||
|
currentRoom: get().currentRoom?.room_id === id ? null : get().currentRoom
|
||||||
|
})
|
||||||
|
},
|
||||||
|
|
||||||
|
startDiscussion: async (roomId, objective) => {
|
||||||
|
await chatroomApi.startDiscussion(roomId, objective)
|
||||||
|
get().updateRoomStatus('active')
|
||||||
|
},
|
||||||
|
|
||||||
|
pauseDiscussion: async (roomId) => {
|
||||||
|
await chatroomApi.pauseDiscussion(roomId)
|
||||||
|
get().updateRoomStatus('paused')
|
||||||
|
},
|
||||||
|
|
||||||
|
resumeDiscussion: async (roomId) => {
|
||||||
|
await chatroomApi.resumeDiscussion(roomId)
|
||||||
|
get().updateRoomStatus('active')
|
||||||
|
},
|
||||||
|
|
||||||
|
stopDiscussion: async (roomId) => {
|
||||||
|
await chatroomApi.stopDiscussion(roomId)
|
||||||
|
},
|
||||||
|
|
||||||
|
connectWebSocket: async (roomId) => {
|
||||||
|
await wsService.connect(roomId)
|
||||||
|
|
||||||
|
// 监听消息
|
||||||
|
wsService.onMessage((wsMsg: WSMessage) => {
|
||||||
|
switch (wsMsg.type) {
|
||||||
|
case 'message':
|
||||||
|
if (wsMsg.data) {
|
||||||
|
get().addMessage(wsMsg.data as Message)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
case 'typing':
|
||||||
|
if (wsMsg.agent_id) {
|
||||||
|
get().setTyping(wsMsg.agent_id, wsMsg.is_typing ?? false)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
case 'status':
|
||||||
|
get().updateRoomStatus(wsMsg.status ?? '')
|
||||||
|
break
|
||||||
|
case 'round':
|
||||||
|
get().updateRoomStatus('active', wsMsg.round)
|
||||||
|
break
|
||||||
|
case 'error':
|
||||||
|
set({ error: wsMsg.error ?? '未知错误' })
|
||||||
|
break
|
||||||
|
}
|
||||||
|
})
|
||||||
|
},
|
||||||
|
|
||||||
|
disconnectWebSocket: () => {
|
||||||
|
wsService.disconnect()
|
||||||
|
},
|
||||||
|
|
||||||
|
fetchMessages: async (roomId) => {
|
||||||
|
try {
|
||||||
|
const messages = await chatroomApi.getMessages(roomId)
|
||||||
|
set({ messages: messages.reverse() })
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Failed to fetch messages:', e)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
addMessage: (message) => {
|
||||||
|
set({ messages: [...get().messages, message] })
|
||||||
|
},
|
||||||
|
|
||||||
|
setTyping: (agentId, isTyping) => {
|
||||||
|
const typingAgents = new Set(get().typingAgents)
|
||||||
|
if (isTyping) {
|
||||||
|
typingAgents.add(agentId)
|
||||||
|
} else {
|
||||||
|
typingAgents.delete(agentId)
|
||||||
|
}
|
||||||
|
set({ typingAgents })
|
||||||
|
},
|
||||||
|
|
||||||
|
updateRoomStatus: (status, round) => {
|
||||||
|
const currentRoom = get().currentRoom
|
||||||
|
if (currentRoom) {
|
||||||
|
set({
|
||||||
|
currentRoom: {
|
||||||
|
...currentRoom,
|
||||||
|
status: status as ChatRoom['status'],
|
||||||
|
current_round: round ?? currentRoom.current_round
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
62
frontend/src/stores/providerStore.ts
Normal file
62
frontend/src/stores/providerStore.ts
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
/**
|
||||||
|
* AI接口状态管理
|
||||||
|
*/
|
||||||
|
import { create } from 'zustand'
|
||||||
|
import type { AIProvider } from '../types'
|
||||||
|
import { providerApi } from '../services/api'
|
||||||
|
|
||||||
|
interface ProviderState {
|
||||||
|
providers: AIProvider[]
|
||||||
|
loading: boolean
|
||||||
|
error: string | null
|
||||||
|
|
||||||
|
// Actions
|
||||||
|
fetchProviders: () => Promise<void>
|
||||||
|
createProvider: (data: Partial<AIProvider>) => Promise<AIProvider>
|
||||||
|
updateProvider: (id: string, data: Partial<AIProvider>) => Promise<AIProvider>
|
||||||
|
deleteProvider: (id: string) => Promise<void>
|
||||||
|
testProvider: (id: string) => Promise<{ success: boolean; message: string }>
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useProviderStore = create<ProviderState>((set, get) => ({
|
||||||
|
providers: [],
|
||||||
|
loading: false,
|
||||||
|
error: null,
|
||||||
|
|
||||||
|
fetchProviders: async () => {
|
||||||
|
set({ loading: true, error: null })
|
||||||
|
try {
|
||||||
|
const providers = await providerApi.list()
|
||||||
|
set({ providers, loading: false })
|
||||||
|
} catch (e) {
|
||||||
|
set({ error: (e as Error).message, loading: false })
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
createProvider: async (data) => {
|
||||||
|
const provider = await providerApi.create(data)
|
||||||
|
set({ providers: [...get().providers, provider] })
|
||||||
|
return provider
|
||||||
|
},
|
||||||
|
|
||||||
|
updateProvider: async (id, data) => {
|
||||||
|
const provider = await providerApi.update(id, data)
|
||||||
|
set({
|
||||||
|
providers: get().providers.map(p =>
|
||||||
|
p.provider_id === id ? provider : p
|
||||||
|
)
|
||||||
|
})
|
||||||
|
return provider
|
||||||
|
},
|
||||||
|
|
||||||
|
deleteProvider: async (id) => {
|
||||||
|
await providerApi.delete(id)
|
||||||
|
set({
|
||||||
|
providers: get().providers.filter(p => p.provider_id !== id)
|
||||||
|
})
|
||||||
|
},
|
||||||
|
|
||||||
|
testProvider: async (id) => {
|
||||||
|
return await providerApi.test(id)
|
||||||
|
}
|
||||||
|
}))
|
||||||
150
frontend/src/types/index.ts
Normal file
150
frontend/src/types/index.ts
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
/**
|
||||||
|
* TypeScript类型定义
|
||||||
|
*/
|
||||||
|
|
||||||
|
// AI接口提供商
|
||||||
|
export interface ProxyConfig {
|
||||||
|
http_proxy?: string
|
||||||
|
https_proxy?: string
|
||||||
|
no_proxy?: string[]
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface RateLimit {
|
||||||
|
requests_per_minute: number
|
||||||
|
tokens_per_minute: number
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AIProvider {
|
||||||
|
provider_id: string
|
||||||
|
provider_type: string
|
||||||
|
name: string
|
||||||
|
api_key_masked: string
|
||||||
|
base_url: string
|
||||||
|
model: string
|
||||||
|
use_proxy: boolean
|
||||||
|
proxy_config: ProxyConfig
|
||||||
|
rate_limit: RateLimit
|
||||||
|
timeout: number
|
||||||
|
extra_params: Record<string, unknown>
|
||||||
|
enabled: boolean
|
||||||
|
created_at: string
|
||||||
|
updated_at: string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Agent
|
||||||
|
export interface AgentCapabilities {
|
||||||
|
memory_enabled: boolean
|
||||||
|
mcp_tools: string[]
|
||||||
|
skills: string[]
|
||||||
|
multimodal: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AgentBehavior {
|
||||||
|
speak_threshold: number
|
||||||
|
max_speak_per_round: number
|
||||||
|
speak_style: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface Agent {
|
||||||
|
agent_id: string
|
||||||
|
name: string
|
||||||
|
role: string
|
||||||
|
system_prompt: string
|
||||||
|
provider_id: string
|
||||||
|
temperature: number
|
||||||
|
max_tokens: number
|
||||||
|
capabilities: AgentCapabilities
|
||||||
|
behavior: AgentBehavior
|
||||||
|
avatar?: string
|
||||||
|
color: string
|
||||||
|
enabled: boolean
|
||||||
|
created_at: string
|
||||||
|
updated_at: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AgentTemplate {
|
||||||
|
template_id: string
|
||||||
|
name: string
|
||||||
|
role: string
|
||||||
|
system_prompt: string
|
||||||
|
color: string
|
||||||
|
}
|
||||||
|
|
||||||
|
// 聊天室
|
||||||
|
export interface ChatRoomConfig {
|
||||||
|
max_rounds: number
|
||||||
|
message_history_size: number
|
||||||
|
consensus_threshold: number
|
||||||
|
round_interval: number
|
||||||
|
allow_user_interrupt: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ChatRoom {
|
||||||
|
room_id: string
|
||||||
|
name: string
|
||||||
|
description: string
|
||||||
|
objective: string
|
||||||
|
agents: string[]
|
||||||
|
moderator_agent_id?: string
|
||||||
|
config: ChatRoomConfig
|
||||||
|
status: 'idle' | 'active' | 'paused' | 'completed' | 'error'
|
||||||
|
current_round: number
|
||||||
|
current_discussion_id?: string
|
||||||
|
created_at: string
|
||||||
|
updated_at: string
|
||||||
|
completed_at?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
// 消息
|
||||||
|
export interface Message {
|
||||||
|
message_id: string
|
||||||
|
room_id: string
|
||||||
|
discussion_id: string
|
||||||
|
agent_id?: string
|
||||||
|
content: string
|
||||||
|
message_type: 'text' | 'image' | 'file' | 'system' | 'action'
|
||||||
|
round: number
|
||||||
|
created_at: string
|
||||||
|
}
|
||||||
|
|
||||||
|
// 讨论结果
|
||||||
|
export interface DiscussionResult {
|
||||||
|
discussion_id: string
|
||||||
|
room_id: string
|
||||||
|
objective: string
|
||||||
|
consensus_reached: boolean
|
||||||
|
confidence: number
|
||||||
|
summary: string
|
||||||
|
action_items: string[]
|
||||||
|
unresolved_issues: string[]
|
||||||
|
key_decisions: string[]
|
||||||
|
total_rounds: number
|
||||||
|
total_messages: number
|
||||||
|
participating_agents: string[]
|
||||||
|
agent_contributions: Record<string, number>
|
||||||
|
status: string
|
||||||
|
end_reason: string
|
||||||
|
created_at: string
|
||||||
|
completed_at?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSocket消息类型
|
||||||
|
export interface WSMessage {
|
||||||
|
type: 'message' | 'status' | 'typing' | 'round' | 'error'
|
||||||
|
data?: unknown
|
||||||
|
status?: string
|
||||||
|
agent_id?: string
|
||||||
|
is_typing?: boolean
|
||||||
|
round?: number
|
||||||
|
total_rounds?: number
|
||||||
|
error?: string
|
||||||
|
timestamp?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
// API响应
|
||||||
|
export interface TestResponse {
|
||||||
|
success: boolean
|
||||||
|
message: string
|
||||||
|
model?: string
|
||||||
|
latency_ms?: number
|
||||||
|
}
|
||||||
25
frontend/tsconfig.json
Normal file
25
frontend/tsconfig.json
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
{
|
||||||
|
"compilerOptions": {
|
||||||
|
"target": "ES2020",
|
||||||
|
"useDefineForClassFields": true,
|
||||||
|
"lib": ["ES2020", "DOM", "DOM.Iterable"],
|
||||||
|
"module": "ESNext",
|
||||||
|
"skipLibCheck": true,
|
||||||
|
"moduleResolution": "bundler",
|
||||||
|
"allowImportingTsExtensions": true,
|
||||||
|
"resolveJsonModule": true,
|
||||||
|
"isolatedModules": true,
|
||||||
|
"noEmit": true,
|
||||||
|
"jsx": "react-jsx",
|
||||||
|
"strict": true,
|
||||||
|
"noUnusedLocals": true,
|
||||||
|
"noUnusedParameters": true,
|
||||||
|
"noFallthroughCasesInSwitch": true,
|
||||||
|
"baseUrl": ".",
|
||||||
|
"paths": {
|
||||||
|
"@/*": ["src/*"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"include": ["src"],
|
||||||
|
"references": [{ "path": "./tsconfig.node.json" }]
|
||||||
|
}
|
||||||
10
frontend/tsconfig.node.json
Normal file
10
frontend/tsconfig.node.json
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
{
|
||||||
|
"compilerOptions": {
|
||||||
|
"composite": true,
|
||||||
|
"skipLibCheck": true,
|
||||||
|
"module": "ESNext",
|
||||||
|
"moduleResolution": "bundler",
|
||||||
|
"allowSyntheticDefaultImports": true
|
||||||
|
},
|
||||||
|
"include": ["vite.config.ts"]
|
||||||
|
}
|
||||||
19
frontend/vite.config.ts
Normal file
19
frontend/vite.config.ts
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
import { defineConfig } from 'vite'
|
||||||
|
import react from '@vitejs/plugin-react'
|
||||||
|
|
||||||
|
export default defineConfig({
|
||||||
|
plugins: [react()],
|
||||||
|
server: {
|
||||||
|
port: 3000,
|
||||||
|
proxy: {
|
||||||
|
'/api': {
|
||||||
|
target: 'http://localhost:8000',
|
||||||
|
changeOrigin: true
|
||||||
|
},
|
||||||
|
'/ws': {
|
||||||
|
target: 'ws://localhost:8000',
|
||||||
|
ws: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
Reference in New Issue
Block a user