feat: 更新模块文档,添加详细说明和使用示例
This commit is contained in:
@@ -1,3 +1,36 @@
|
|||||||
"""MineNASAI - 基于NAS的智能个人AI助理"""
|
"""MineNASAI - 基于NAS的智能个人AI助理
|
||||||
|
|
||||||
|
MineNASAI 是一个部署在家用 NAS 上的智能 AI 助理系统,支持:
|
||||||
|
- 多渠道接入(飞书、企业微信等)
|
||||||
|
- 多 LLM 后端(OpenAI、Anthropic、Gemini、国产模型等)
|
||||||
|
- 智能任务路由与复杂度评估
|
||||||
|
- 安全的工具执行与权限管理
|
||||||
|
- Web 终端管理界面
|
||||||
|
|
||||||
|
模块结构:
|
||||||
|
- core: 核心基础设施(配置、日志、缓存、监控、数据库)
|
||||||
|
- agent: AI Agent 运行时和工具管理
|
||||||
|
- gateway: 消息网关和渠道接入
|
||||||
|
- llm: 多模型 LLM 客户端
|
||||||
|
- scheduler: 定时任务调度
|
||||||
|
- webtui: Web 管理界面
|
||||||
|
|
||||||
|
快速开始:
|
||||||
|
>>> from minenasai.core import get_settings, setup_logging
|
||||||
|
>>> from minenasai.llm import get_llm_manager
|
||||||
|
>>> from minenasai.agent import get_agent_runtime
|
||||||
|
>>>
|
||||||
|
>>> # 初始化
|
||||||
|
>>> settings = get_settings()
|
||||||
|
>>> setup_logging(settings)
|
||||||
|
>>>
|
||||||
|
>>> # 获取 LLM 管理器
|
||||||
|
>>> llm_manager = get_llm_manager()
|
||||||
|
>>>
|
||||||
|
>>> # 获取 Agent 运行时
|
||||||
|
>>> agent = await get_agent_runtime()
|
||||||
|
"""
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.0"
|
||||||
|
__author__ = "MineNASAI Team"
|
||||||
|
__license__ = "MIT"
|
||||||
|
|||||||
@@ -1,21 +1,53 @@
|
|||||||
"""Agent 模块
|
"""Agent 模块 - AI Agent 运行时和工具管理
|
||||||
|
|
||||||
提供 Agent 运行时、工具管理、会话管理
|
本模块提供 AI Agent 的核心功能:
|
||||||
|
|
||||||
|
Agent 运行时 (runtime):
|
||||||
|
- AgentRuntime: Agent 执行引擎,支持多轮对话和工具调用
|
||||||
|
- get_agent_runtime(): 获取 Agent 运行时单例
|
||||||
|
|
||||||
|
权限管理 (permissions):
|
||||||
|
- PermissionManager: 工具权限管理器
|
||||||
|
- DangerLevel: 危险等级枚举(SAFE/LOW/MEDIUM/HIGH/CRITICAL)
|
||||||
|
- 支持细粒度的路径和参数权限控制
|
||||||
|
- 高危操作需要用户确认
|
||||||
|
|
||||||
|
工具注册 (tool_registry):
|
||||||
|
- ToolRegistry: 工具注册表
|
||||||
|
- @tool 装饰器: 快速注册工具函数
|
||||||
|
- register_builtin_tools(): 注册内置工具
|
||||||
|
|
||||||
|
内置工具 (tools):
|
||||||
|
- read_file: 读取文件内容
|
||||||
|
- list_directory: 列出目录文件
|
||||||
|
- python_eval: 安全的 Python 表达式求值
|
||||||
|
|
||||||
|
使用示例:
|
||||||
|
>>> from minenasai.agent import get_agent_runtime, tool
|
||||||
|
>>>
|
||||||
|
>>> # 注册自定义工具
|
||||||
|
>>> @tool(name="my_tool", description="自定义工具")
|
||||||
|
... async def my_tool(param: str) -> str:
|
||||||
|
... return f"处理: {param}"
|
||||||
|
>>>
|
||||||
|
>>> # 运行 Agent
|
||||||
|
>>> agent = await get_agent_runtime()
|
||||||
|
>>> response = await agent.run("帮我处理这个任务")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from minenasai.agent.runtime import AgentRuntime, get_agent_runtime
|
|
||||||
from minenasai.agent.tools.basic import get_basic_tools
|
|
||||||
from minenasai.agent.permissions import (
|
from minenasai.agent.permissions import (
|
||||||
DangerLevel,
|
DangerLevel,
|
||||||
PermissionManager,
|
PermissionManager,
|
||||||
get_permission_manager,
|
get_permission_manager,
|
||||||
)
|
)
|
||||||
|
from minenasai.agent.runtime import AgentRuntime, get_agent_runtime
|
||||||
from minenasai.agent.tool_registry import (
|
from minenasai.agent.tool_registry import (
|
||||||
ToolRegistry,
|
ToolRegistry,
|
||||||
get_tool_registry,
|
get_tool_registry,
|
||||||
register_builtin_tools,
|
register_builtin_tools,
|
||||||
tool,
|
tool,
|
||||||
)
|
)
|
||||||
|
from minenasai.agent.tools.basic import get_basic_tools
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AgentRuntime",
|
"AgentRuntime",
|
||||||
|
|||||||
@@ -6,17 +6,21 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import StrEnum
|
||||||
from typing import Any, Callable, Coroutine
|
from typing import Any
|
||||||
|
|
||||||
from minenasai.core import get_logger, get_settings
|
from minenasai.core import get_logger, get_settings
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DangerLevel(str, Enum):
|
class DangerLevel(StrEnum):
|
||||||
"""危险等级"""
|
"""危险等级
|
||||||
|
|
||||||
|
定义工具操作的风险级别,用于权限控制和确认流程。
|
||||||
|
"""
|
||||||
|
|
||||||
SAFE = "safe" # 只读操作,无风险
|
SAFE = "safe" # 只读操作,无风险
|
||||||
LOW = "low" # 低风险,写入工作目录
|
LOW = "low" # 低风险,写入工作目录
|
||||||
@@ -25,8 +29,11 @@ class DangerLevel(str, Enum):
|
|||||||
CRITICAL = "critical" # 极高危,系统级操作
|
CRITICAL = "critical" # 极高危,系统级操作
|
||||||
|
|
||||||
|
|
||||||
class ConfirmationStatus(str, Enum):
|
class ConfirmationStatus(StrEnum):
|
||||||
"""确认状态"""
|
"""确认状态
|
||||||
|
|
||||||
|
高危操作确认请求的状态。
|
||||||
|
"""
|
||||||
|
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
APPROVED = "approved"
|
APPROVED = "approved"
|
||||||
@@ -168,14 +175,14 @@ class PermissionManager:
|
|||||||
self,
|
self,
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
params: dict[str, Any] | None = None,
|
params: dict[str, Any] | None = None,
|
||||||
user_id: str | None = None,
|
_user_id: str | None = None, # 预留:将来用于基于用户的权限检查
|
||||||
) -> tuple[bool, str]:
|
) -> tuple[bool, str]:
|
||||||
"""检查工具执行权限
|
"""检查工具执行权限
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_name: 工具名称
|
tool_name: 工具名称
|
||||||
params: 执行参数
|
params: 执行参数
|
||||||
user_id: 用户 ID
|
_user_id: 用户 ID(预留参数,将来用于基于用户的权限检查)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(是否允许, 原因)
|
(是否允许, 原因)
|
||||||
@@ -375,7 +382,35 @@ _permission_manager: PermissionManager | None = None
|
|||||||
|
|
||||||
|
|
||||||
def get_permission_manager() -> PermissionManager:
|
def get_permission_manager() -> PermissionManager:
|
||||||
"""获取全局权限管理器"""
|
"""获取全局权限管理器单例
|
||||||
|
|
||||||
|
返回全局唯一的权限管理器实例,用于检查工具执行权限
|
||||||
|
和管理高危操作确认流程。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PermissionManager: 全局权限管理器
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> pm = get_permission_manager()
|
||||||
|
>>>
|
||||||
|
>>> # 检查权限
|
||||||
|
>>> allowed, reason = pm.check_permission(
|
||||||
|
... "write_file",
|
||||||
|
... {"path": "/tmp/test.txt"}
|
||||||
|
... )
|
||||||
|
>>>
|
||||||
|
>>> # 检查是否需要确认
|
||||||
|
>>> if pm.requires_confirmation("delete_file"):
|
||||||
|
... request_id = await pm.request_confirmation(
|
||||||
|
... "delete_file",
|
||||||
|
... {"path": "/important/file.txt"}
|
||||||
|
... )
|
||||||
|
... # 等待用户确认...
|
||||||
|
|
||||||
|
Note:
|
||||||
|
权限管理器会根据配置文件中的 security 设置
|
||||||
|
决定哪些危险等级的操作需要用户确认。
|
||||||
|
"""
|
||||||
global _permission_manager
|
global _permission_manager
|
||||||
if _permission_manager is None:
|
if _permission_manager is None:
|
||||||
_permission_manager = PermissionManager()
|
_permission_manager = PermissionManager()
|
||||||
|
|||||||
@@ -6,11 +6,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from typing import Any, AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from minenasai.core import get_audit_logger, get_logger, get_settings
|
from minenasai.core import get_audit_logger, get_logger, get_settings
|
||||||
from minenasai.core.session_store import get_session_store
|
from minenasai.core.session_store import get_session_store
|
||||||
from minenasai.llm import LLMManager, get_llm_manager
|
from minenasai.llm import get_llm_manager
|
||||||
from minenasai.llm.base import Message, Provider, ToolDefinition
|
from minenasai.llm.base import Message, Provider, ToolDefinition
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|||||||
@@ -6,11 +6,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Callable, Coroutine, get_type_hints
|
from typing import Any, get_type_hints
|
||||||
|
|
||||||
from minenasai.core import get_logger
|
|
||||||
from minenasai.agent.permissions import DangerLevel, ToolPermission, get_permission_manager
|
from minenasai.agent.permissions import DangerLevel, ToolPermission, get_permission_manager
|
||||||
|
from minenasai.core import get_logger
|
||||||
from minenasai.llm.base import ToolDefinition
|
from minenasai.llm.base import ToolDefinition
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
"""内置工具集"""
|
"""内置工具集"""
|
||||||
|
|
||||||
from minenasai.agent.tools.basic import (
|
from minenasai.agent.tools.basic import (
|
||||||
read_file_tool,
|
|
||||||
list_directory_tool,
|
list_directory_tool,
|
||||||
python_eval_tool,
|
python_eval_tool,
|
||||||
|
read_file_tool,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
@@ -5,8 +5,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from minenasai.core import get_logger
|
from minenasai.core import get_logger
|
||||||
|
|||||||
@@ -1,6 +1,42 @@
|
|||||||
"""核心模块
|
"""核心模块 - 提供系统基础设施
|
||||||
|
|
||||||
提供配置管理、日志系统、数据库、监控、缓存等基础功能
|
本模块包含 MineNASAI 的核心基础功能:
|
||||||
|
|
||||||
|
配置管理 (config):
|
||||||
|
- Settings: 全局配置数据类
|
||||||
|
- get_settings(): 获取配置单例
|
||||||
|
- load_config(): 从文件加载配置
|
||||||
|
|
||||||
|
日志系统 (logging):
|
||||||
|
- setup_logging(): 初始化日志系统
|
||||||
|
- get_logger(): 获取模块日志记录器
|
||||||
|
- AuditLogger: 审计日志记录器
|
||||||
|
|
||||||
|
监控系统 (monitoring):
|
||||||
|
- SystemMetrics: 系统指标收集
|
||||||
|
- HealthChecker: 健康检查器
|
||||||
|
- setup_monitoring(): 初始化监控
|
||||||
|
|
||||||
|
缓存系统 (cache):
|
||||||
|
- MemoryCache: 内存缓存,支持 TTL 和 LRU 淘汰
|
||||||
|
- RateLimiter: 令牌桶限流器
|
||||||
|
|
||||||
|
数据库 (database):
|
||||||
|
- Database: SQLite 异步数据库封装
|
||||||
|
- 支持会话、消息、任务的持久化
|
||||||
|
|
||||||
|
使用示例:
|
||||||
|
>>> from minenasai.core import get_settings, setup_logging, get_logger
|
||||||
|
>>>
|
||||||
|
>>> # 初始化配置和日志
|
||||||
|
>>> settings = get_settings()
|
||||||
|
>>> setup_logging(settings)
|
||||||
|
>>> logger = get_logger(__name__)
|
||||||
|
>>>
|
||||||
|
>>> # 使用缓存
|
||||||
|
>>> from minenasai.core import get_response_cache
|
||||||
|
>>> cache = get_response_cache()
|
||||||
|
>>> await cache.set("key", "value", ttl=300)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from minenasai.core.cache import (
|
from minenasai.core.cache import (
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import hashlib
|
import hashlib
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
@@ -71,10 +72,8 @@ class MemoryCache(Generic[T]):
|
|||||||
"""停止后台清理任务"""
|
"""停止后台清理任务"""
|
||||||
if self._cleanup_task:
|
if self._cleanup_task:
|
||||||
self._cleanup_task.cancel()
|
self._cleanup_task.cancel()
|
||||||
try:
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
await self._cleanup_task
|
await self._cleanup_task
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
self._cleanup_task = None
|
self._cleanup_task = None
|
||||||
|
|
||||||
async def _cleanup_loop(self) -> None:
|
async def _cleanup_loop(self) -> None:
|
||||||
|
|||||||
@@ -296,7 +296,25 @@ _settings: Settings | None = None
|
|||||||
|
|
||||||
|
|
||||||
def get_settings() -> Settings:
|
def get_settings() -> Settings:
|
||||||
"""获取全局配置实例"""
|
"""获取全局配置单例
|
||||||
|
|
||||||
|
返回全局唯一的配置实例。首次调用时从配置文件加载,
|
||||||
|
后续调用返回缓存的实例。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Settings: 全局配置对象
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> settings = get_settings()
|
||||||
|
>>> print(settings.app.name)
|
||||||
|
MineNASAI
|
||||||
|
>>> print(settings.llm.default_model)
|
||||||
|
claude-sonnet-4-20250514
|
||||||
|
|
||||||
|
Note:
|
||||||
|
配置加载优先级:环境变量 > 配置文件 > 默认值
|
||||||
|
使用 reset_settings() 可重置配置(仅用于测试)
|
||||||
|
"""
|
||||||
global _settings
|
global _settings
|
||||||
if _settings is None:
|
if _settings is None:
|
||||||
_settings = load_config()
|
_settings = load_config()
|
||||||
@@ -304,6 +322,13 @@ def get_settings() -> Settings:
|
|||||||
|
|
||||||
|
|
||||||
def reset_settings() -> None:
|
def reset_settings() -> None:
|
||||||
"""重置全局配置(用于测试)"""
|
"""重置全局配置单例
|
||||||
|
|
||||||
|
清除缓存的配置实例,下次调用 get_settings() 时将重新加载配置。
|
||||||
|
主要用于测试场景,生产环境慎用。
|
||||||
|
|
||||||
|
Warning:
|
||||||
|
此函数会清除所有已加载的配置,可能影响正在运行的组件。
|
||||||
|
"""
|
||||||
global _settings
|
global _settings
|
||||||
_settings = None
|
_settings = None
|
||||||
|
|||||||
@@ -11,10 +11,11 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Callable, Coroutine
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import FastAPI, Request, Response
|
from fastapi import FastAPI, Request, Response
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
@@ -161,7 +162,7 @@ class HealthChecker:
|
|||||||
result.latency_ms = (time.time() - start_time) * 1000
|
result.latency_ms = (time.time() - start_time) * 1000
|
||||||
self._results[name] = result
|
self._results[name] = result
|
||||||
return result
|
return result
|
||||||
except asyncio.TimeoutError:
|
except TimeoutError:
|
||||||
result = ComponentHealth(
|
result = ComponentHealth(
|
||||||
name=name,
|
name=name,
|
||||||
status=HealthStatus.UNHEALTHY,
|
status=HealthStatus.UNHEALTHY,
|
||||||
|
|||||||
@@ -7,8 +7,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Iterator
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Iterator
|
from typing import Any
|
||||||
|
|
||||||
from minenasai.core.config import expand_path
|
from minenasai.core.config import expand_path
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,51 @@
|
|||||||
"""Gateway 模块
|
"""Gateway 模块 - 消息网关和渠道接入
|
||||||
|
|
||||||
提供 WebSocket 服务、通讯渠道接入、智能路由
|
本模块提供消息网关服务,支持多渠道消息接入:
|
||||||
|
|
||||||
|
WebSocket 服务 (server):
|
||||||
|
- 提供实时双向通信
|
||||||
|
- 支持多客户端连接管理
|
||||||
|
- 心跳检测和自动重连
|
||||||
|
|
||||||
|
渠道接入 (channels):
|
||||||
|
- FeishuChannel: 飞书机器人接入
|
||||||
|
- WeWorkChannel: 企业微信机器人接入
|
||||||
|
- 统一的消息格式转换
|
||||||
|
|
||||||
|
智能路由 (router):
|
||||||
|
- SmartRouter: 智能任务路由器
|
||||||
|
- 自动评估消息复杂度
|
||||||
|
- 根据复杂度选择处理策略
|
||||||
|
|
||||||
|
协议定义 (protocol):
|
||||||
|
- ChatMessage: 聊天消息模型
|
||||||
|
- MessageType: 消息类型枚举
|
||||||
|
- ChannelType: 渠道类型枚举
|
||||||
|
|
||||||
|
使用示例:
|
||||||
|
>>> from minenasai.gateway.router import SmartRouter
|
||||||
|
>>> from minenasai.gateway.protocol.schema import ChatMessage
|
||||||
|
>>>
|
||||||
|
>>> # 创建路由器
|
||||||
|
>>> router = SmartRouter()
|
||||||
|
>>>
|
||||||
|
>>> # 评估任务复杂度
|
||||||
|
>>> message = ChatMessage(content="帮我写一个 Python 爬虫")
|
||||||
|
>>> complexity = router.route(message)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__all__ = []
|
from minenasai.gateway.protocol.schema import (
|
||||||
|
ChannelType,
|
||||||
|
ChatMessage,
|
||||||
|
MessageType,
|
||||||
|
TaskComplexity,
|
||||||
|
)
|
||||||
|
from minenasai.gateway.router import SmartRouter
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ChatMessage",
|
||||||
|
"MessageType",
|
||||||
|
"ChannelType",
|
||||||
|
"TaskComplexity",
|
||||||
|
"SmartRouter",
|
||||||
|
]
|
||||||
|
|||||||
@@ -5,8 +5,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import hmac
|
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -33,8 +31,12 @@ class FeishuChannel(BaseChannel):
|
|||||||
self._tenant_access_token: str | None = None
|
self._tenant_access_token: str | None = None
|
||||||
self._token_expires: float = 0
|
self._token_expires: float = 0
|
||||||
|
|
||||||
async def verify_signature(self, request: Any) -> bool:
|
async def verify_signature(self, _request: Any) -> bool:
|
||||||
"""验证飞书签名"""
|
"""验证飞书签名
|
||||||
|
|
||||||
|
Args:
|
||||||
|
_request: HTTP 请求对象(待实现时使用)
|
||||||
|
"""
|
||||||
# TODO: 实现签名验证
|
# TODO: 实现签名验证
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -142,10 +144,10 @@ class FeishuChannel(BaseChannel):
|
|||||||
# 判断是用户还是群聊
|
# 判断是用户还是群聊
|
||||||
chat_id = kwargs.get("chat_id")
|
chat_id = kwargs.get("chat_id")
|
||||||
if chat_id:
|
if chat_id:
|
||||||
url = f"https://open.feishu.cn/open-apis/im/v1/messages?receive_id_type=chat_id"
|
url = "https://open.feishu.cn/open-apis/im/v1/messages?receive_id_type=chat_id"
|
||||||
receive_id = chat_id
|
receive_id = chat_id
|
||||||
else:
|
else:
|
||||||
url = f"https://open.feishu.cn/open-apis/im/v1/messages?receive_id_type=user_id"
|
url = "https://open.feishu.cn/open-apis/im/v1/messages?receive_id_type=user_id"
|
||||||
receive_id = peer_id
|
receive_id = peer_id
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|||||||
@@ -5,7 +5,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -33,10 +32,13 @@ class WeworkChannel(BaseChannel):
|
|||||||
self._access_token: str | None = None
|
self._access_token: str | None = None
|
||||||
self._token_expires: float = 0
|
self._token_expires: float = 0
|
||||||
|
|
||||||
async def verify_signature(self, request: Any) -> bool:
|
async def verify_signature(self, _request: Any) -> bool:
|
||||||
"""验证企业微信签名
|
"""验证企业微信签名
|
||||||
|
|
||||||
企业微信使用 SHA1 签名验证
|
企业微信使用 SHA1 签名验证。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
_request: HTTP 请求对象(待实现时使用)
|
||||||
"""
|
"""
|
||||||
# TODO: 实现签名验证
|
# TODO: 实现签名验证
|
||||||
# signature = sorted([self.token, timestamp, nonce, echostr])
|
# signature = sorted([self.token, timestamp, nonce, echostr])
|
||||||
@@ -105,9 +107,16 @@ class WeworkChannel(BaseChannel):
|
|||||||
peer_id: str,
|
peer_id: str,
|
||||||
content: str,
|
content: str,
|
||||||
message_type: str = "text",
|
message_type: str = "text",
|
||||||
**kwargs: Any,
|
**_kwargs: Any, # 预留:将来支持更多消息参数
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""发送企业微信消息"""
|
"""发送企业微信消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
peer_id: 接收者用户 ID
|
||||||
|
content: 消息内容
|
||||||
|
message_type: 消息类型
|
||||||
|
**_kwargs: 预留参数,将来支持更多消息选项
|
||||||
|
"""
|
||||||
access_token = await self._get_access_token()
|
access_token = await self._get_access_token()
|
||||||
url = f"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={access_token}"
|
url = f"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={access_token}"
|
||||||
|
|
||||||
|
|||||||
@@ -5,14 +5,17 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from enum import Enum
|
from enum import StrEnum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
class MessageType(str, Enum):
|
class MessageType(StrEnum):
|
||||||
"""消息类型"""
|
"""消息类型
|
||||||
|
|
||||||
|
定义 WebSocket 通信中的消息类型。
|
||||||
|
"""
|
||||||
|
|
||||||
# 客户端 -> 服务器
|
# 客户端 -> 服务器
|
||||||
CHAT = "chat" # 普通聊天消息
|
CHAT = "chat" # 普通聊天消息
|
||||||
@@ -32,8 +35,11 @@ class MessageType(str, Enum):
|
|||||||
CONFIRM = "confirm" # 确认请求/响应
|
CONFIRM = "confirm" # 确认请求/响应
|
||||||
|
|
||||||
|
|
||||||
class ChannelType(str, Enum):
|
class ChannelType(StrEnum):
|
||||||
"""渠道类型"""
|
"""渠道类型
|
||||||
|
|
||||||
|
定义消息来源的渠道类型。
|
||||||
|
"""
|
||||||
|
|
||||||
WEWORK = "wework"
|
WEWORK = "wework"
|
||||||
FEISHU = "feishu"
|
FEISHU = "feishu"
|
||||||
@@ -41,8 +47,11 @@ class ChannelType(str, Enum):
|
|||||||
WEBSOCKET = "websocket"
|
WEBSOCKET = "websocket"
|
||||||
|
|
||||||
|
|
||||||
class TaskComplexity(str, Enum):
|
class TaskComplexity(StrEnum):
|
||||||
"""任务复杂度"""
|
"""任务复杂度
|
||||||
|
|
||||||
|
用于智能路由,决定任务的处理方式。
|
||||||
|
"""
|
||||||
|
|
||||||
SIMPLE = "simple" # 简单查询,直接回复
|
SIMPLE = "simple" # 简单查询,直接回复
|
||||||
MEDIUM = "medium" # 中等任务,需要工具
|
MEDIUM = "medium" # 中等任务,需要工具
|
||||||
|
|||||||
@@ -6,8 +6,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Any, AsyncGenerator
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
@@ -152,10 +153,16 @@ async def list_sessions(agent_id: str | None = None) -> dict[str, Any]:
|
|||||||
|
|
||||||
async def handle_chat_message(
|
async def handle_chat_message(
|
||||||
websocket: WebSocket,
|
websocket: WebSocket,
|
||||||
client_id: str,
|
_client_id: str, # 预留:将来用于客户端标识和会话管理
|
||||||
message: ChatMessage,
|
message: ChatMessage,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""处理聊天消息"""
|
"""处理聊天消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
websocket: WebSocket 连接
|
||||||
|
_client_id: 客户端 ID(预留,将来用于会话管理)
|
||||||
|
message: 聊天消息
|
||||||
|
"""
|
||||||
router = get_router()
|
router = get_router()
|
||||||
|
|
||||||
# 发送思考状态
|
# 发送思考状态
|
||||||
|
|||||||
@@ -1,6 +1,44 @@
|
|||||||
"""LLM 多模型客户端模块
|
"""LLM 多模型客户端模块 - 统一的 AI API 接口
|
||||||
|
|
||||||
支持多个 AI API 提供商的统一接口
|
本模块提供统一的 LLM 客户端接口,支持多个 AI 提供商:
|
||||||
|
|
||||||
|
支持的提供商:
|
||||||
|
- OpenAI (GPT-4, GPT-3.5)
|
||||||
|
- Anthropic (Claude 3)
|
||||||
|
- Google (Gemini)
|
||||||
|
- DeepSeek
|
||||||
|
- Moonshot (月之暗面)
|
||||||
|
- MiniMax
|
||||||
|
- Zhipu (智谱)
|
||||||
|
|
||||||
|
核心类:
|
||||||
|
- BaseLLMClient: LLM 客户端基类
|
||||||
|
- LLMManager: 多模型管理器,支持负载均衡和故障转移
|
||||||
|
- Message: 统一消息格式
|
||||||
|
- LLMResponse: 统一响应格式
|
||||||
|
- ToolCall: 工具调用定义
|
||||||
|
|
||||||
|
特性:
|
||||||
|
- 统一的 API 接口
|
||||||
|
- 自动重试和错误处理
|
||||||
|
- 流式响应支持
|
||||||
|
- 工具调用 (Function Calling) 支持
|
||||||
|
- 响应缓存
|
||||||
|
|
||||||
|
使用示例:
|
||||||
|
>>> from minenasai.llm import get_llm_manager, Message
|
||||||
|
>>>
|
||||||
|
>>> # 获取 LLM 管理器
|
||||||
|
>>> manager = get_llm_manager()
|
||||||
|
>>>
|
||||||
|
>>> # 发送消息
|
||||||
|
>>> messages = [Message(role="user", content="你好")]
|
||||||
|
>>> response = await manager.chat(messages)
|
||||||
|
>>> print(response.content)
|
||||||
|
>>>
|
||||||
|
>>> # 流式响应
|
||||||
|
>>> async for chunk in manager.stream(messages):
|
||||||
|
... print(chunk.content, end="")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from minenasai.llm.base import BaseLLMClient, LLMResponse, Message, ToolCall
|
from minenasai.llm.base import BaseLLMClient, LLMResponse, Message, ToolCall
|
||||||
|
|||||||
@@ -6,13 +6,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import StrEnum
|
||||||
from typing import Any, AsyncIterator
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class Provider(str, Enum):
|
class Provider(StrEnum):
|
||||||
"""LLM 提供商"""
|
"""LLM 提供商
|
||||||
|
|
||||||
|
支持的 AI API 提供商枚举。
|
||||||
|
"""
|
||||||
|
|
||||||
ANTHROPIC = "anthropic" # Claude
|
ANTHROPIC = "anthropic" # Claude
|
||||||
OPENAI = "openai" # GPT
|
OPENAI = "openai" # GPT
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
"""LLM 客户端实现"""
|
"""LLM 客户端实现"""
|
||||||
|
|
||||||
from minenasai.llm.clients.anthropic import AnthropicClient
|
from minenasai.llm.clients.anthropic import AnthropicClient
|
||||||
from minenasai.llm.clients.openai_compat import OpenAICompatClient
|
|
||||||
from minenasai.llm.clients.deepseek import DeepSeekClient
|
from minenasai.llm.clients.deepseek import DeepSeekClient
|
||||||
from minenasai.llm.clients.zhipu import ZhipuClient
|
from minenasai.llm.clients.gemini import GeminiClient
|
||||||
from minenasai.llm.clients.minimax import MiniMaxClient
|
from minenasai.llm.clients.minimax import MiniMaxClient
|
||||||
from minenasai.llm.clients.moonshot import MoonshotClient
|
from minenasai.llm.clients.moonshot import MoonshotClient
|
||||||
from minenasai.llm.clients.gemini import GeminiClient
|
from minenasai.llm.clients.openai_compat import OpenAICompatClient
|
||||||
|
from minenasai.llm.clients.zhipu import ZhipuClient
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AnthropicClient",
|
"AnthropicClient",
|
||||||
|
|||||||
@@ -3,7 +3,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Any, AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
@@ -74,7 +75,7 @@ class AnthropicClient(BaseLLMClient):
|
|||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
tools: list[ToolDefinition] | None = None,
|
tools: list[ToolDefinition] | None = None,
|
||||||
**kwargs: Any,
|
**_kwargs: Any, # 预留:基类接口扩展参数
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""发送聊天请求"""
|
"""发送聊天请求"""
|
||||||
client = await self._get_client()
|
client = await self._get_client()
|
||||||
@@ -142,7 +143,7 @@ class AnthropicClient(BaseLLMClient):
|
|||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
tools: list[ToolDefinition] | None = None,
|
tools: list[ToolDefinition] | None = None,
|
||||||
**kwargs: Any,
|
**_kwargs: Any, # 预留:基类接口扩展参数
|
||||||
) -> AsyncIterator[StreamChunk]:
|
) -> AsyncIterator[StreamChunk]:
|
||||||
"""流式聊天"""
|
"""流式聊天"""
|
||||||
client = await self._get_client()
|
client = await self._get_client()
|
||||||
|
|||||||
@@ -6,7 +6,8 @@ Gemini API 使用独特的接口格式
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Any, AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
@@ -76,7 +77,7 @@ class GeminiClient(BaseLLMClient):
|
|||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
tools: list[ToolDefinition] | None = None,
|
tools: list[ToolDefinition] | None = None,
|
||||||
**kwargs: Any,
|
**_kwargs: Any, # 预留:基类接口扩展参数
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""发送聊天请求"""
|
"""发送聊天请求"""
|
||||||
client = await self._get_client()
|
client = await self._get_client()
|
||||||
@@ -164,10 +165,10 @@ class GeminiClient(BaseLLMClient):
|
|||||||
system: str | None = None,
|
system: str | None = None,
|
||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
tools: list[ToolDefinition] | None = None,
|
_tools: list[ToolDefinition] | None = None, # 暂不支持流式工具调用
|
||||||
**kwargs: Any,
|
**_kwargs: Any, # 预留:基类接口扩展参数
|
||||||
) -> AsyncIterator[StreamChunk]:
|
) -> AsyncIterator[StreamChunk]:
|
||||||
"""流式聊天"""
|
"""流式聊天(暂不支持工具调用)"""
|
||||||
client = await self._get_client()
|
client = await self._get_client()
|
||||||
model = model or self.default_model
|
model = model or self.default_model
|
||||||
|
|
||||||
|
|||||||
@@ -5,8 +5,6 @@ MiniMax 使用 OpenAI 兼容接口
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from minenasai.llm.base import Provider
|
from minenasai.llm.base import Provider
|
||||||
|
|||||||
@@ -6,7 +6,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Any, AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
@@ -87,7 +88,7 @@ class OpenAICompatClient(BaseLLMClient):
|
|||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
tools: list[ToolDefinition] | None = None,
|
tools: list[ToolDefinition] | None = None,
|
||||||
**kwargs: Any,
|
**_kwargs: Any, # 预留:基类接口扩展参数
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""发送聊天请求"""
|
"""发送聊天请求"""
|
||||||
client = await self._get_client()
|
client = await self._get_client()
|
||||||
@@ -154,7 +155,7 @@ class OpenAICompatClient(BaseLLMClient):
|
|||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
tools: list[ToolDefinition] | None = None,
|
tools: list[ToolDefinition] | None = None,
|
||||||
**kwargs: Any,
|
**_kwargs: Any, # 预留:基类接口扩展参数
|
||||||
) -> AsyncIterator[StreamChunk]:
|
) -> AsyncIterator[StreamChunk]:
|
||||||
"""流式聊天"""
|
"""流式聊天"""
|
||||||
client = await self._get_client()
|
client = await self._get_client()
|
||||||
|
|||||||
@@ -5,7 +5,8 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from minenasai.core import get_logger, get_settings
|
from minenasai.core import get_logger, get_settings
|
||||||
from minenasai.llm.base import (
|
from minenasai.llm.base import (
|
||||||
@@ -142,10 +143,10 @@ class LLMManager:
|
|||||||
client = self._create_client(provider)
|
client = self._create_client(provider)
|
||||||
if client:
|
if client:
|
||||||
self._clients[provider] = client
|
self._clients[provider] = client
|
||||||
logger.info(f"LLM 客户端已初始化", provider=provider.display_name)
|
logger.info("LLM 客户端已初始化", provider=provider.display_name)
|
||||||
|
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
logger.info(f"LLM Manager 初始化完成", providers=list(self._clients.keys()))
|
logger.info("LLM Manager 初始化完成", providers=list(self._clients.keys()))
|
||||||
|
|
||||||
def get_client(self, provider: Provider) -> BaseLLMClient | None:
|
def get_client(self, provider: Provider) -> BaseLLMClient | None:
|
||||||
"""获取指定提供商的客户端"""
|
"""获取指定提供商的客户端"""
|
||||||
@@ -334,7 +335,30 @@ _llm_manager: LLMManager | None = None
|
|||||||
|
|
||||||
|
|
||||||
def get_llm_manager() -> LLMManager:
|
def get_llm_manager() -> LLMManager:
|
||||||
"""获取全局 LLM 管理器"""
|
"""获取全局 LLM 管理器单例
|
||||||
|
|
||||||
|
返回全局唯一的 LLM 管理器实例。管理器会根据配置自动初始化
|
||||||
|
可用的 LLM 客户端。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LLMManager: 全局 LLM 管理器
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> manager = get_llm_manager()
|
||||||
|
>>> # 发送消息
|
||||||
|
>>> response = await manager.chat([
|
||||||
|
... Message(role="user", content="你好")
|
||||||
|
... ])
|
||||||
|
>>> print(response.content)
|
||||||
|
>>>
|
||||||
|
>>> # 流式响应
|
||||||
|
>>> async for chunk in manager.stream(messages):
|
||||||
|
... print(chunk.content, end="")
|
||||||
|
|
||||||
|
Note:
|
||||||
|
首次调用时会自动初始化所有配置了 API Key 的客户端。
|
||||||
|
支持的提供商取决于环境变量或配置文件中的 API Key 设置。
|
||||||
|
"""
|
||||||
global _llm_manager
|
global _llm_manager
|
||||||
if _llm_manager is None:
|
if _llm_manager is None:
|
||||||
_llm_manager = LLMManager()
|
_llm_manager = LLMManager()
|
||||||
|
|||||||
@@ -1,6 +1,43 @@
|
|||||||
"""定时任务调度模块
|
"""定时任务调度模块 - Cron 任务调度
|
||||||
|
|
||||||
提供 Cron 任务调度功能
|
本模块提供类似 Linux cron 的定时任务调度功能:
|
||||||
|
|
||||||
|
核心类:
|
||||||
|
- CronScheduler: 调度器,管理所有定时任务
|
||||||
|
- CronJob: 定时任务定义
|
||||||
|
|
||||||
|
Cron 表达式:
|
||||||
|
支持标准 5 字段 cron 表达式:分 时 日 月 周
|
||||||
|
- * : 任意值
|
||||||
|
- */n : 每隔 n
|
||||||
|
- n-m : 范围
|
||||||
|
- n,m : 列表
|
||||||
|
|
||||||
|
预设表达式:
|
||||||
|
- @hourly : 每小时
|
||||||
|
- @daily : 每天
|
||||||
|
- @weekly : 每周
|
||||||
|
- @monthly : 每月
|
||||||
|
|
||||||
|
使用示例:
|
||||||
|
>>> from minenasai.scheduler import get_scheduler
|
||||||
|
>>>
|
||||||
|
>>> # 获取调度器
|
||||||
|
>>> scheduler = get_scheduler()
|
||||||
|
>>>
|
||||||
|
>>> # 添加定时任务
|
||||||
|
>>> async def my_task():
|
||||||
|
... print("任务执行中...")
|
||||||
|
>>>
|
||||||
|
>>> scheduler.add_job(
|
||||||
|
... job_id="my_job",
|
||||||
|
... cron_expr="0 9 * * *", # 每天 9:00
|
||||||
|
... func=my_task,
|
||||||
|
... description="每日任务"
|
||||||
|
... )
|
||||||
|
>>>
|
||||||
|
>>> # 启动调度器
|
||||||
|
>>> await scheduler.start()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from minenasai.scheduler.cron import CronJob, CronScheduler, get_scheduler
|
from minenasai.scheduler.cron import CronJob, CronScheduler, get_scheduler
|
||||||
|
|||||||
@@ -6,21 +6,24 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import contextlib
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from enum import Enum
|
from enum import StrEnum
|
||||||
from typing import Any, Callable, Coroutine
|
from typing import Any
|
||||||
|
|
||||||
from minenasai.core import get_logger
|
from minenasai.core import get_logger
|
||||||
from minenasai.core.database import get_database
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class JobStatus(str, Enum):
|
class JobStatus(StrEnum):
|
||||||
"""任务状态"""
|
"""任务状态
|
||||||
|
|
||||||
|
定时任务的执行状态。
|
||||||
|
"""
|
||||||
|
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
RUNNING = "running"
|
RUNNING = "running"
|
||||||
@@ -283,10 +286,8 @@ class CronScheduler:
|
|||||||
self._running = False
|
self._running = False
|
||||||
if self._task:
|
if self._task:
|
||||||
self._task.cancel()
|
self._task.cancel()
|
||||||
try:
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
await self._task
|
await self._task
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
logger.info("定时调度器已停止")
|
logger.info("定时调度器已停止")
|
||||||
|
|
||||||
async def _run_loop(self) -> None:
|
async def _run_loop(self) -> None:
|
||||||
|
|||||||
@@ -1,6 +1,36 @@
|
|||||||
"""Web TUI 模块
|
"""Web TUI 模块 - Web 管理界面
|
||||||
|
|
||||||
提供 Web 终端界面、SSH 连接管理
|
本模块提供基于 Web 的终端管理界面:
|
||||||
|
|
||||||
|
认证管理 (auth):
|
||||||
|
- AuthManager: 认证管理器
|
||||||
|
- AuthToken: 认证令牌
|
||||||
|
- 支持 JWT 令牌认证
|
||||||
|
- 令牌刷新和撤销
|
||||||
|
|
||||||
|
SSH 管理 (ssh_manager):
|
||||||
|
- SSHManager: SSH 连接管理器
|
||||||
|
- SSHSession: SSH 会话封装
|
||||||
|
- 支持多会话管理
|
||||||
|
- WebSocket 实时终端
|
||||||
|
|
||||||
|
Web 服务 (server):
|
||||||
|
- FastAPI 应用
|
||||||
|
- 静态文件服务
|
||||||
|
- WebSocket 终端
|
||||||
|
- RESTful API
|
||||||
|
|
||||||
|
使用示例:
|
||||||
|
>>> from minenasai.webtui import get_auth_manager
|
||||||
|
>>>
|
||||||
|
>>> # 获取认证管理器
|
||||||
|
>>> auth = get_auth_manager()
|
||||||
|
>>>
|
||||||
|
>>> # 生成令牌
|
||||||
|
>>> token = auth.generate_token(user_id="admin")
|
||||||
|
>>>
|
||||||
|
>>> # 验证令牌
|
||||||
|
>>> user_info = auth.verify_token(token.token)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from minenasai.webtui.auth import AuthManager, AuthToken, get_auth_manager
|
from minenasai.webtui.auth import AuthManager, AuthToken, get_auth_manager
|
||||||
|
|||||||
@@ -7,13 +7,14 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import uuid
|
import uuid
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, AsyncGenerator
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import FileResponse, HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
from minenasai.core import get_logger, get_settings, setup_logging
|
from minenasai.core import get_logger, get_settings, setup_logging
|
||||||
@@ -95,8 +96,12 @@ manager = ConnectionManager()
|
|||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
"""应用生命周期管理"""
|
"""应用生命周期管理
|
||||||
|
|
||||||
|
Args:
|
||||||
|
_app: FastAPI 应用实例(lifespan 标准签名要求)
|
||||||
|
"""
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
setup_logging(settings.logging)
|
setup_logging(settings.logging)
|
||||||
logger.info("Web TUI 服务启动", port=settings.webtui.port)
|
logger.info("Web TUI 服务启动", port=settings.webtui.port)
|
||||||
|
|||||||
@@ -6,10 +6,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from collections.abc import Callable
|
||||||
from typing import Any, Callable
|
from typing import Any
|
||||||
|
|
||||||
import paramiko
|
import paramiko
|
||||||
|
|
||||||
@@ -128,10 +129,8 @@ class SSHSession:
|
|||||||
|
|
||||||
if self._read_task:
|
if self._read_task:
|
||||||
self._read_task.cancel()
|
self._read_task.cancel()
|
||||||
try:
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
await self._read_task
|
await self._read_task
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if self.channel:
|
if self.channel:
|
||||||
self.channel.close()
|
self.channel.close()
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from minenasai.core import Settings, get_settings, load_config, reset_settings
|
from minenasai.core import Settings, get_settings, load_config, reset_settings
|
||||||
from minenasai.core.config import expand_path
|
from minenasai.core.config import expand_path
|
||||||
|
|
||||||
|
|||||||
345
tests/test_database.py
Normal file
345
tests/test_database.py
Normal file
@@ -0,0 +1,345 @@
|
|||||||
|
"""数据库模块测试
|
||||||
|
|
||||||
|
测试 SQLite 异步数据库操作
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from minenasai.core.database import Database
|
||||||
|
|
||||||
|
|
||||||
|
class TestDatabase:
|
||||||
|
"""数据库基本操作测试"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db(self, tmp_path):
|
||||||
|
"""创建临时数据库"""
|
||||||
|
db_path = tmp_path / "test.db"
|
||||||
|
db = Database(db_path)
|
||||||
|
await db.connect()
|
||||||
|
yield db
|
||||||
|
await db.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_connect_and_close(self, tmp_path):
|
||||||
|
"""测试数据库连接和关闭"""
|
||||||
|
db_path = tmp_path / "test.db"
|
||||||
|
db = Database(db_path)
|
||||||
|
|
||||||
|
# 连接前访问 conn 应该报错
|
||||||
|
with pytest.raises(RuntimeError, match="数据库未连接"):
|
||||||
|
_ = db.conn
|
||||||
|
|
||||||
|
# 连接
|
||||||
|
await db.connect()
|
||||||
|
assert db.conn is not None
|
||||||
|
|
||||||
|
# 关闭
|
||||||
|
await db.close()
|
||||||
|
assert db._conn is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentOperations:
|
||||||
|
"""Agent 操作测试"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db(self, tmp_path):
|
||||||
|
"""创建临时数据库"""
|
||||||
|
db_path = tmp_path / "test.db"
|
||||||
|
db = Database(db_path)
|
||||||
|
await db.connect()
|
||||||
|
yield db
|
||||||
|
await db.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_agent(self, db):
|
||||||
|
"""测试创建 Agent"""
|
||||||
|
agent = await db.create_agent(
|
||||||
|
agent_id="test-agent-1",
|
||||||
|
name="Test Agent",
|
||||||
|
workspace_path="/tmp/workspace",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert agent["id"] == "test-agent-1"
|
||||||
|
assert agent["name"] == "Test Agent"
|
||||||
|
assert agent["workspace_path"] == "/tmp/workspace"
|
||||||
|
assert agent["model"] == "claude-sonnet-4-20250514"
|
||||||
|
assert agent["sandbox_mode"] == "workspace"
|
||||||
|
assert "created_at" in agent
|
||||||
|
assert "updated_at" in agent
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_agent_custom_model(self, db):
|
||||||
|
"""测试使用自定义模型创建 Agent"""
|
||||||
|
agent = await db.create_agent(
|
||||||
|
agent_id="test-agent-2",
|
||||||
|
name="Custom Agent",
|
||||||
|
workspace_path="/tmp/workspace",
|
||||||
|
model="gpt-4",
|
||||||
|
sandbox_mode="strict",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert agent["model"] == "gpt-4"
|
||||||
|
assert agent["sandbox_mode"] == "strict"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_agent(self, db):
|
||||||
|
"""测试获取 Agent"""
|
||||||
|
# 创建 Agent
|
||||||
|
await db.create_agent(
|
||||||
|
agent_id="test-agent-3",
|
||||||
|
name="Get Test Agent",
|
||||||
|
workspace_path="/tmp/workspace",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取存在的 Agent
|
||||||
|
agent = await db.get_agent("test-agent-3")
|
||||||
|
assert agent is not None
|
||||||
|
assert agent["name"] == "Get Test Agent"
|
||||||
|
|
||||||
|
# 获取不存在的 Agent
|
||||||
|
agent = await db.get_agent("nonexistent")
|
||||||
|
assert agent is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_agents(self, db):
|
||||||
|
"""测试列出所有 Agent"""
|
||||||
|
# 初始为空
|
||||||
|
agents = await db.list_agents()
|
||||||
|
assert len(agents) == 0
|
||||||
|
|
||||||
|
# 创建多个 Agent
|
||||||
|
await db.create_agent("agent-1", "Agent 1", "/tmp/1")
|
||||||
|
await db.create_agent("agent-2", "Agent 2", "/tmp/2")
|
||||||
|
await db.create_agent("agent-3", "Agent 3", "/tmp/3")
|
||||||
|
|
||||||
|
# 列出所有
|
||||||
|
agents = await db.list_agents()
|
||||||
|
assert len(agents) == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestSessionOperations:
|
||||||
|
"""Session 操作测试"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db(self, tmp_path):
|
||||||
|
"""创建临时数据库并添加测试 Agent"""
|
||||||
|
db_path = tmp_path / "test.db"
|
||||||
|
db = Database(db_path)
|
||||||
|
await db.connect()
|
||||||
|
await db.create_agent("test-agent", "Test Agent", "/tmp/workspace")
|
||||||
|
yield db
|
||||||
|
await db.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_session(self, db):
|
||||||
|
"""测试创建会话"""
|
||||||
|
session = await db.create_session(
|
||||||
|
agent_id="test-agent",
|
||||||
|
channel="websocket",
|
||||||
|
peer_id="user-123",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert session["agent_id"] == "test-agent"
|
||||||
|
assert session["channel"] == "websocket"
|
||||||
|
assert session["peer_id"] == "user-123"
|
||||||
|
assert session["status"] == "active"
|
||||||
|
assert "session_key" in session
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_session_with_metadata(self, db):
|
||||||
|
"""测试创建带元数据的会话"""
|
||||||
|
metadata = {"client": "web", "version": "1.0"}
|
||||||
|
session = await db.create_session(
|
||||||
|
agent_id="test-agent",
|
||||||
|
channel="websocket",
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert session["metadata"] == metadata
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_session(self, db):
|
||||||
|
"""测试获取会话"""
|
||||||
|
# 创建会话
|
||||||
|
created = await db.create_session(
|
||||||
|
agent_id="test-agent",
|
||||||
|
channel="websocket",
|
||||||
|
metadata={"test": "data"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取会话
|
||||||
|
session = await db.get_session(created["id"])
|
||||||
|
assert session is not None
|
||||||
|
assert session["id"] == created["id"]
|
||||||
|
assert session["metadata"] == {"test": "data"}
|
||||||
|
|
||||||
|
# 获取不存在的会话
|
||||||
|
session = await db.get_session("nonexistent-id")
|
||||||
|
assert session is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_session_status(self, db):
|
||||||
|
"""测试更新会话状态"""
|
||||||
|
# 创建会话
|
||||||
|
session = await db.create_session(
|
||||||
|
agent_id="test-agent",
|
||||||
|
channel="websocket",
|
||||||
|
)
|
||||||
|
assert session["status"] == "active"
|
||||||
|
|
||||||
|
# 更新状态
|
||||||
|
await db.update_session_status(session["id"], "closed")
|
||||||
|
|
||||||
|
# 验证
|
||||||
|
updated = await db.get_session(session["id"])
|
||||||
|
assert updated["status"] == "closed"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_active_sessions(self, db):
|
||||||
|
"""测试列出活跃会话"""
|
||||||
|
# 创建多个会话
|
||||||
|
session1 = await db.create_session("test-agent", "websocket")
|
||||||
|
session2 = await db.create_session("test-agent", "feishu")
|
||||||
|
await db.create_session("test-agent", "wework")
|
||||||
|
|
||||||
|
# 关闭一个会话
|
||||||
|
await db.update_session_status(session2["id"], "closed")
|
||||||
|
|
||||||
|
# 列出所有活跃会话
|
||||||
|
active = await db.list_active_sessions()
|
||||||
|
assert len(active) == 2
|
||||||
|
|
||||||
|
# 按 agent_id 过滤
|
||||||
|
active = await db.list_active_sessions("test-agent")
|
||||||
|
assert len(active) == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageOperations:
|
||||||
|
"""Message 操作测试"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db(self, tmp_path):
|
||||||
|
"""创建临时数据库并添加测试会话"""
|
||||||
|
db_path = tmp_path / "test.db"
|
||||||
|
db = Database(db_path)
|
||||||
|
await db.connect()
|
||||||
|
await db.create_agent("test-agent", "Test Agent", "/tmp/workspace")
|
||||||
|
session = await db.create_session("test-agent", "websocket")
|
||||||
|
db._test_session_id = session["id"]
|
||||||
|
yield db
|
||||||
|
await db.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_message(self, db):
|
||||||
|
"""测试添加消息"""
|
||||||
|
session_id = db._test_session_id
|
||||||
|
|
||||||
|
message = await db.add_message(
|
||||||
|
session_id=session_id,
|
||||||
|
role="user",
|
||||||
|
content="Hello, world!",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert message["session_id"] == session_id
|
||||||
|
assert message["role"] == "user"
|
||||||
|
assert message["content"] == "Hello, world!"
|
||||||
|
assert message["tokens_used"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_message_with_tool_calls(self, db):
|
||||||
|
"""测试添加带工具调用的消息"""
|
||||||
|
session_id = db._test_session_id
|
||||||
|
|
||||||
|
tool_calls = [
|
||||||
|
{"id": "call-1", "name": "read_file", "arguments": {"path": "/tmp/test"}}
|
||||||
|
]
|
||||||
|
message = await db.add_message(
|
||||||
|
session_id=session_id,
|
||||||
|
role="assistant",
|
||||||
|
content=None,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
tokens_used=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert message["tool_calls"] == tool_calls
|
||||||
|
assert message["tokens_used"] == 100
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_messages(self, db):
|
||||||
|
"""测试获取会话消息"""
|
||||||
|
session_id = db._test_session_id
|
||||||
|
|
||||||
|
# 添加多条消息
|
||||||
|
await db.add_message(session_id, "user", "Message 1")
|
||||||
|
await db.add_message(session_id, "assistant", "Response 1")
|
||||||
|
await db.add_message(session_id, "user", "Message 2")
|
||||||
|
await db.add_message(session_id, "assistant", "Response 2")
|
||||||
|
|
||||||
|
# 获取所有消息
|
||||||
|
messages = await db.get_messages(session_id)
|
||||||
|
assert len(messages) == 4
|
||||||
|
# 验证角色交替
|
||||||
|
roles = [m["role"] for m in messages]
|
||||||
|
assert roles.count("user") == 2
|
||||||
|
assert roles.count("assistant") == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_messages_with_limit(self, db):
|
||||||
|
"""测试获取消息数量限制"""
|
||||||
|
session_id = db._test_session_id
|
||||||
|
|
||||||
|
# 添加多条消息
|
||||||
|
for i in range(10):
|
||||||
|
await db.add_message(session_id, "user", f"Message {i}")
|
||||||
|
|
||||||
|
# 限制返回数量
|
||||||
|
messages = await db.get_messages(session_id, limit=5)
|
||||||
|
assert len(messages) == 5
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLog:
|
||||||
|
"""审计日志测试"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db(self, tmp_path):
|
||||||
|
"""创建临时数据库"""
|
||||||
|
db_path = tmp_path / "test.db"
|
||||||
|
db = Database(db_path)
|
||||||
|
await db.connect()
|
||||||
|
yield db
|
||||||
|
await db.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_audit_log(self, db):
|
||||||
|
"""测试添加审计日志"""
|
||||||
|
# 添加日志不应该报错
|
||||||
|
await db.add_audit_log(
|
||||||
|
agent_id="test-agent",
|
||||||
|
tool_name="read_file",
|
||||||
|
danger_level="safe",
|
||||||
|
params={"path": "/tmp/test.txt"},
|
||||||
|
result="success",
|
||||||
|
duration_ms=50,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_audit_log_minimal(self, db):
|
||||||
|
"""测试添加最小审计日志"""
|
||||||
|
await db.add_audit_log(
|
||||||
|
agent_id=None,
|
||||||
|
tool_name="python_eval",
|
||||||
|
danger_level="low",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGlobalDatabase:
|
||||||
|
"""全局数据库实例测试"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_import_functions(self):
|
||||||
|
"""测试导入全局函数"""
|
||||||
|
from minenasai.core.database import close_database, get_database
|
||||||
|
|
||||||
|
assert callable(get_database)
|
||||||
|
assert callable(close_database)
|
||||||
@@ -142,3 +142,159 @@ class TestRouterEdgeCases:
|
|||||||
result = self.router.evaluate("查看 /tmp/test.txt 文件内容")
|
result = self.router.evaluate("查看 /tmp/test.txt 文件内容")
|
||||||
|
|
||||||
assert result["complexity"] in [TaskComplexity.SIMPLE, TaskComplexity.MEDIUM]
|
assert result["complexity"] in [TaskComplexity.SIMPLE, TaskComplexity.MEDIUM]
|
||||||
|
|
||||||
|
|
||||||
|
class TestConnectionManager:
|
||||||
|
"""WebSocket 连接管理器测试"""
|
||||||
|
|
||||||
|
def test_import_manager(self):
|
||||||
|
"""测试导入连接管理器"""
|
||||||
|
from minenasai.gateway.server import ConnectionManager
|
||||||
|
|
||||||
|
manager = ConnectionManager()
|
||||||
|
assert manager.active_connections == {}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_connect_and_disconnect(self):
|
||||||
|
"""测试连接和断开"""
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
from minenasai.gateway.server import ConnectionManager
|
||||||
|
|
||||||
|
manager = ConnectionManager()
|
||||||
|
|
||||||
|
# Mock WebSocket
|
||||||
|
mock_ws = AsyncMock()
|
||||||
|
mock_ws.accept = AsyncMock()
|
||||||
|
|
||||||
|
# 连接
|
||||||
|
await manager.connect(mock_ws, "client-1")
|
||||||
|
assert "client-1" in manager.active_connections
|
||||||
|
mock_ws.accept.assert_called_once()
|
||||||
|
|
||||||
|
# 断开
|
||||||
|
manager.disconnect("client-1")
|
||||||
|
assert "client-1" not in manager.active_connections
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_disconnect_nonexistent(self):
|
||||||
|
"""测试断开不存在的连接"""
|
||||||
|
from minenasai.gateway.server import ConnectionManager
|
||||||
|
|
||||||
|
manager = ConnectionManager()
|
||||||
|
# 不应该抛出异常
|
||||||
|
manager.disconnect("nonexistent")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_message(self):
|
||||||
|
"""测试发送消息"""
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from minenasai.gateway.server import ConnectionManager
|
||||||
|
|
||||||
|
manager = ConnectionManager()
|
||||||
|
|
||||||
|
# Mock WebSocket
|
||||||
|
mock_ws = AsyncMock()
|
||||||
|
mock_ws.accept = AsyncMock()
|
||||||
|
mock_ws.send_json = AsyncMock()
|
||||||
|
|
||||||
|
# 连接
|
||||||
|
await manager.connect(mock_ws, "client-1")
|
||||||
|
|
||||||
|
# 发送消息
|
||||||
|
await manager.send_message("client-1", {"type": "test"})
|
||||||
|
mock_ws.send_json.assert_called_once_with({"type": "test"})
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_message_to_nonexistent(self):
|
||||||
|
"""测试发送消息给不存在的客户端"""
|
||||||
|
from minenasai.gateway.server import ConnectionManager
|
||||||
|
|
||||||
|
manager = ConnectionManager()
|
||||||
|
# 不应该抛出异常
|
||||||
|
await manager.send_message("nonexistent", {"type": "test"})
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_broadcast(self):
|
||||||
|
"""测试广播消息"""
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from minenasai.gateway.server import ConnectionManager
|
||||||
|
|
||||||
|
manager = ConnectionManager()
|
||||||
|
|
||||||
|
# Mock 多个 WebSocket
|
||||||
|
mock_ws1 = AsyncMock()
|
||||||
|
mock_ws1.accept = AsyncMock()
|
||||||
|
mock_ws1.send_json = AsyncMock()
|
||||||
|
|
||||||
|
mock_ws2 = AsyncMock()
|
||||||
|
mock_ws2.accept = AsyncMock()
|
||||||
|
mock_ws2.send_json = AsyncMock()
|
||||||
|
|
||||||
|
# 连接
|
||||||
|
await manager.connect(mock_ws1, "client-1")
|
||||||
|
await manager.connect(mock_ws2, "client-2")
|
||||||
|
|
||||||
|
# 广播
|
||||||
|
await manager.broadcast({"type": "broadcast"})
|
||||||
|
mock_ws1.send_json.assert_called_once_with({"type": "broadcast"})
|
||||||
|
mock_ws2.send_json.assert_called_once_with({"type": "broadcast"})
|
||||||
|
|
||||||
|
|
||||||
|
class TestGatewayServer:
|
||||||
|
"""Gateway 服务器测试"""
|
||||||
|
|
||||||
|
def test_import_app(self):
|
||||||
|
"""测试导入应用"""
|
||||||
|
from minenasai.gateway.server import app
|
||||||
|
|
||||||
|
assert app is not None
|
||||||
|
assert app.title == "MineNASAI Gateway"
|
||||||
|
|
||||||
|
def test_import_endpoints(self):
|
||||||
|
"""测试导入端点函数"""
|
||||||
|
from minenasai.gateway.server import list_agents, list_sessions, root
|
||||||
|
|
||||||
|
assert callable(root)
|
||||||
|
assert callable(list_agents)
|
||||||
|
assert callable(list_sessions)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageTypes:
|
||||||
|
"""消息类型测试"""
|
||||||
|
|
||||||
|
def test_status_message(self):
|
||||||
|
"""测试状态消息"""
|
||||||
|
from minenasai.gateway.protocol import StatusMessage
|
||||||
|
|
||||||
|
msg = StatusMessage(status="thinking", message="处理中...")
|
||||||
|
assert msg.type == MessageType.STATUS
|
||||||
|
assert msg.status == "thinking"
|
||||||
|
assert msg.message == "处理中..."
|
||||||
|
|
||||||
|
def test_response_message(self):
|
||||||
|
"""测试响应消息"""
|
||||||
|
from minenasai.gateway.protocol import ResponseMessage
|
||||||
|
|
||||||
|
msg = ResponseMessage(content="Hello!", in_reply_to="msg-123")
|
||||||
|
assert msg.type == MessageType.RESPONSE
|
||||||
|
assert msg.content == "Hello!"
|
||||||
|
assert msg.in_reply_to == "msg-123"
|
||||||
|
|
||||||
|
def test_error_message(self):
|
||||||
|
"""测试错误消息"""
|
||||||
|
from minenasai.gateway.protocol import ErrorMessage
|
||||||
|
|
||||||
|
msg = ErrorMessage(message="Something went wrong", code="ERR_001")
|
||||||
|
assert msg.type == MessageType.ERROR
|
||||||
|
assert msg.message == "Something went wrong"
|
||||||
|
assert msg.code == "ERR_001"
|
||||||
|
|
||||||
|
def test_pong_message(self):
|
||||||
|
"""测试心跳响应消息"""
|
||||||
|
from minenasai.gateway.protocol import PongMessage
|
||||||
|
|
||||||
|
msg = PongMessage()
|
||||||
|
assert msg.type == MessageType.PONG
|
||||||
|
|||||||
@@ -2,9 +2,18 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from minenasai.llm.base import Message, Provider, ToolCall, ToolDefinition
|
from minenasai.llm.base import (
|
||||||
|
LLMResponse,
|
||||||
|
Message,
|
||||||
|
Provider,
|
||||||
|
StreamChunk,
|
||||||
|
ToolCall,
|
||||||
|
ToolDefinition,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestProvider:
|
class TestProvider:
|
||||||
@@ -133,3 +142,208 @@ class TestLLMManager:
|
|||||||
providers = manager.get_available_providers()
|
providers = manager.get_available_providers()
|
||||||
# 可能为空,取决于环境变量
|
# 可能为空,取决于环境变量
|
||||||
assert isinstance(providers, list)
|
assert isinstance(providers, list)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAICompatClientMock:
|
||||||
|
"""OpenAI 兼容客户端 Mock 测试"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(self):
|
||||||
|
"""创建测试客户端"""
|
||||||
|
from minenasai.llm.clients import OpenAICompatClient
|
||||||
|
|
||||||
|
return OpenAICompatClient(api_key="test-key", base_url="https://api.test.com/v1")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_mock(self, client):
|
||||||
|
"""测试聊天功能(Mock)"""
|
||||||
|
# Mock HTTP 响应
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"id": "chatcmpl-123",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1677652288,
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Hello! How can I help you?",
|
||||||
|
},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 10,
|
||||||
|
"completion_tokens": 20,
|
||||||
|
"total_tokens": 30,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client") as mock_get_client:
|
||||||
|
mock_http_client = AsyncMock()
|
||||||
|
mock_http_client.post = AsyncMock(return_value=mock_response)
|
||||||
|
mock_get_client.return_value = mock_http_client
|
||||||
|
|
||||||
|
messages = [Message(role="user", content="Hello")]
|
||||||
|
response = await client.chat(messages)
|
||||||
|
|
||||||
|
assert isinstance(response, LLMResponse)
|
||||||
|
assert response.content == "Hello! How can I help you?"
|
||||||
|
assert response.model == "gpt-4o"
|
||||||
|
assert response.finish_reason == "stop"
|
||||||
|
assert response.usage["total_tokens"] == 30
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_with_tools_mock(self, client):
|
||||||
|
"""测试带工具调用的聊天(Mock)"""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"id": "chatcmpl-456",
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_123",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "read_file",
|
||||||
|
"arguments": '{"path": "/tmp/test.txt"}',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"finish_reason": "tool_calls",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40},
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client") as mock_get_client:
|
||||||
|
mock_http_client = AsyncMock()
|
||||||
|
mock_http_client.post = AsyncMock(return_value=mock_response)
|
||||||
|
mock_get_client.return_value = mock_http_client
|
||||||
|
|
||||||
|
messages = [Message(role="user", content="Read the test file")]
|
||||||
|
tools = [
|
||||||
|
ToolDefinition(
|
||||||
|
name="read_file",
|
||||||
|
description="Read a file",
|
||||||
|
parameters={"type": "object", "properties": {"path": {"type": "string"}}},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
response = await client.chat(messages, tools=tools)
|
||||||
|
|
||||||
|
assert response.tool_calls is not None
|
||||||
|
assert len(response.tool_calls) == 1
|
||||||
|
assert response.tool_calls[0].name == "read_file"
|
||||||
|
assert response.tool_calls[0].arguments == {"path": "/tmp/test.txt"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_client(self, client):
|
||||||
|
"""测试关闭客户端"""
|
||||||
|
# 创建 mock 客户端
|
||||||
|
mock_http_client = AsyncMock()
|
||||||
|
client._client = mock_http_client
|
||||||
|
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
mock_http_client.aclose.assert_called_once()
|
||||||
|
assert client._client is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnthropicClientMock:
|
||||||
|
"""Anthropic 客户端 Mock 测试"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(self):
|
||||||
|
"""创建测试客户端"""
|
||||||
|
from minenasai.llm.clients import AnthropicClient
|
||||||
|
|
||||||
|
return AnthropicClient(api_key="test-key")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_mock(self, client):
|
||||||
|
"""测试聊天功能(Mock)"""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"id": "msg_123",
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [{"type": "text", "text": "Hello from Claude!"}],
|
||||||
|
"model": "claude-sonnet-4-20250514",
|
||||||
|
"stop_reason": "end_turn",
|
||||||
|
"usage": {"input_tokens": 10, "output_tokens": 15},
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client") as mock_get_client:
|
||||||
|
mock_http_client = AsyncMock()
|
||||||
|
mock_http_client.post = AsyncMock(return_value=mock_response)
|
||||||
|
mock_get_client.return_value = mock_http_client
|
||||||
|
|
||||||
|
messages = [Message(role="user", content="Hello")]
|
||||||
|
response = await client.chat(messages)
|
||||||
|
|
||||||
|
assert isinstance(response, LLMResponse)
|
||||||
|
assert response.content == "Hello from Claude!"
|
||||||
|
assert response.provider == Provider.ANTHROPIC
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMResponse:
|
||||||
|
"""LLM 响应测试"""
|
||||||
|
|
||||||
|
def test_response_basic(self):
|
||||||
|
"""测试基本响应"""
|
||||||
|
response = LLMResponse(
|
||||||
|
content="Test response",
|
||||||
|
model="gpt-4o",
|
||||||
|
provider=Provider.OPENAI,
|
||||||
|
)
|
||||||
|
assert response.content == "Test response"
|
||||||
|
assert response.model == "gpt-4o"
|
||||||
|
assert response.provider == Provider.OPENAI
|
||||||
|
assert response.finish_reason == "stop"
|
||||||
|
|
||||||
|
def test_response_with_tool_calls(self):
|
||||||
|
"""测试带工具调用的响应"""
|
||||||
|
tool_calls = [
|
||||||
|
ToolCall(id="tc_1", name="read_file", arguments={"path": "/test"}),
|
||||||
|
ToolCall(id="tc_2", name="list_dir", arguments={"path": "/"}),
|
||||||
|
]
|
||||||
|
response = LLMResponse(
|
||||||
|
content="",
|
||||||
|
model="claude-sonnet-4-20250514",
|
||||||
|
provider=Provider.ANTHROPIC,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
finish_reason="tool_use",
|
||||||
|
)
|
||||||
|
assert len(response.tool_calls) == 2
|
||||||
|
assert response.finish_reason == "tool_use"
|
||||||
|
|
||||||
|
|
||||||
|
class TestStreamChunk:
|
||||||
|
"""流式响应块测试"""
|
||||||
|
|
||||||
|
def test_chunk_basic(self):
|
||||||
|
"""测试基本响应块"""
|
||||||
|
chunk = StreamChunk(content="Hello")
|
||||||
|
assert chunk.content == "Hello"
|
||||||
|
assert chunk.is_final is False
|
||||||
|
|
||||||
|
def test_chunk_final(self):
|
||||||
|
"""测试最终响应块"""
|
||||||
|
chunk = StreamChunk(
|
||||||
|
content="",
|
||||||
|
is_final=True,
|
||||||
|
usage={"prompt_tokens": 10, "completion_tokens": 20},
|
||||||
|
)
|
||||||
|
assert chunk.is_final is True
|
||||||
|
assert chunk.usage is not None
|
||||||
|
|||||||
@@ -176,7 +176,7 @@ class TestToolRegistry:
|
|||||||
|
|
||||||
def test_tool_decorator(self):
|
def test_tool_decorator(self):
|
||||||
"""测试工具装饰器"""
|
"""测试工具装饰器"""
|
||||||
from minenasai.agent import tool, get_tool_registry
|
from minenasai.agent import get_tool_registry, tool
|
||||||
|
|
||||||
@tool(name="decorated_tool", description="装饰器测试")
|
@tool(name="decorated_tool", description="装饰器测试")
|
||||||
async def decorated_tool(param: str) -> dict:
|
async def decorated_tool(param: str) -> dict:
|
||||||
@@ -191,7 +191,7 @@ class TestToolRegistry:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_execute_tool(self):
|
async def test_execute_tool(self):
|
||||||
"""测试执行工具"""
|
"""测试执行工具"""
|
||||||
from minenasai.agent import get_tool_registry, DangerLevel
|
from minenasai.agent import DangerLevel, get_tool_registry
|
||||||
|
|
||||||
registry = get_tool_registry()
|
registry = get_tool_registry()
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from minenasai.webtui.auth import AuthManager, AuthToken
|
from minenasai.webtui.auth import AuthManager, AuthToken
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user