feat: 更新模块文档,添加详细说明和使用示例

This commit is contained in:
锦麟 王
2026-02-05 15:43:08 +08:00
parent 23bf2cfaef
commit 64dc18f477
37 changed files with 1252 additions and 168 deletions

View File

@@ -1,3 +1,36 @@
"""MineNASAI - 基于NAS的智能个人AI助理"""
"""MineNASAI - 基于NAS的智能个人AI助理
MineNASAI 是一个部署在家用 NAS 上的智能 AI 助理系统,支持:
- 多渠道接入(飞书、企业微信等)
- 多 LLM 后端OpenAI、Anthropic、Gemini、国产模型等
- 智能任务路由与复杂度评估
- 安全的工具执行与权限管理
- Web 终端管理界面
模块结构:
- core: 核心基础设施(配置、日志、缓存、监控、数据库)
- agent: AI Agent 运行时和工具管理
- gateway: 消息网关和渠道接入
- llm: 多模型 LLM 客户端
- scheduler: 定时任务调度
- webtui: Web 管理界面
快速开始:
>>> from minenasai.core import get_settings, setup_logging
>>> from minenasai.llm import get_llm_manager
>>> from minenasai.agent import get_agent_runtime
>>>
>>> # 初始化
>>> settings = get_settings()
>>> setup_logging(settings)
>>>
>>> # 获取 LLM 管理器
>>> llm_manager = get_llm_manager()
>>>
>>> # 获取 Agent 运行时
>>> agent = await get_agent_runtime()
"""
__version__ = "0.1.0"
__author__ = "MineNASAI Team"
__license__ = "MIT"

View File

@@ -1,21 +1,53 @@
"""Agent 模块
"""Agent 模块 - AI Agent 运行时和工具管理
提供 Agent 运行时、工具管理、会话管理
本模块提供 AI Agent 的核心功能:
Agent 运行时 (runtime):
- AgentRuntime: Agent 执行引擎,支持多轮对话和工具调用
- get_agent_runtime(): 获取 Agent 运行时单例
权限管理 (permissions):
- PermissionManager: 工具权限管理器
- DangerLevel: 危险等级枚举SAFE/LOW/MEDIUM/HIGH/CRITICAL
- 支持细粒度的路径和参数权限控制
- 高危操作需要用户确认
工具注册 (tool_registry):
- ToolRegistry: 工具注册表
- @tool 装饰器: 快速注册工具函数
- register_builtin_tools(): 注册内置工具
内置工具 (tools):
- read_file: 读取文件内容
- list_directory: 列出目录文件
- python_eval: 安全的 Python 表达式求值
使用示例:
>>> from minenasai.agent import get_agent_runtime, tool
>>>
>>> # 注册自定义工具
>>> @tool(name="my_tool", description="自定义工具")
... async def my_tool(param: str) -> str:
... return f"处理: {param}"
>>>
>>> # 运行 Agent
>>> agent = await get_agent_runtime()
>>> response = await agent.run("帮我处理这个任务")
"""
from minenasai.agent.runtime import AgentRuntime, get_agent_runtime
from minenasai.agent.tools.basic import get_basic_tools
from minenasai.agent.permissions import (
DangerLevel,
PermissionManager,
get_permission_manager,
)
from minenasai.agent.runtime import AgentRuntime, get_agent_runtime
from minenasai.agent.tool_registry import (
ToolRegistry,
get_tool_registry,
register_builtin_tools,
tool,
)
from minenasai.agent.tools.basic import get_basic_tools
__all__ = [
"AgentRuntime",

View File

@@ -6,17 +6,21 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Coroutine
from enum import StrEnum
from typing import Any
from minenasai.core import get_logger, get_settings
logger = get_logger(__name__)
class DangerLevel(str, Enum):
"""危险等级"""
class DangerLevel(StrEnum):
"""危险等级
定义工具操作的风险级别,用于权限控制和确认流程。
"""
SAFE = "safe" # 只读操作,无风险
LOW = "low" # 低风险,写入工作目录
@@ -25,8 +29,11 @@ class DangerLevel(str, Enum):
CRITICAL = "critical" # 极高危,系统级操作
class ConfirmationStatus(str, Enum):
"""确认状态"""
class ConfirmationStatus(StrEnum):
"""确认状态
高危操作确认请求的状态。
"""
PENDING = "pending"
APPROVED = "approved"
@@ -168,14 +175,14 @@ class PermissionManager:
self,
tool_name: str,
params: dict[str, Any] | None = None,
user_id: str | None = None,
_user_id: str | None = None, # 预留:将来用于基于用户的权限检查
) -> tuple[bool, str]:
"""检查工具执行权限
Args:
tool_name: 工具名称
params: 执行参数
user_id: 用户 ID
_user_id: 用户 ID(预留参数,将来用于基于用户的权限检查)
Returns:
(是否允许, 原因)
@@ -375,7 +382,35 @@ _permission_manager: PermissionManager | None = None
def get_permission_manager() -> PermissionManager:
"""获取全局权限管理器"""
"""获取全局权限管理器单例
返回全局唯一的权限管理器实例,用于检查工具执行权限
和管理高危操作确认流程。
Returns:
PermissionManager: 全局权限管理器
Example:
>>> pm = get_permission_manager()
>>>
>>> # 检查权限
>>> allowed, reason = pm.check_permission(
... "write_file",
... {"path": "/tmp/test.txt"}
... )
>>>
>>> # 检查是否需要确认
>>> if pm.requires_confirmation("delete_file"):
... request_id = await pm.request_confirmation(
... "delete_file",
... {"path": "/important/file.txt"}
... )
... # 等待用户确认...
Note:
权限管理器会根据配置文件中的 security 设置
决定哪些危险等级的操作需要用户确认。
"""
global _permission_manager
if _permission_manager is None:
_permission_manager = PermissionManager()

View File

@@ -6,11 +6,12 @@
from __future__ import annotations
import time
from typing import Any, AsyncIterator
from collections.abc import AsyncIterator
from typing import Any
from minenasai.core import get_audit_logger, get_logger, get_settings
from minenasai.core.session_store import get_session_store
from minenasai.llm import LLMManager, get_llm_manager
from minenasai.llm import get_llm_manager
from minenasai.llm.base import Message, Provider, ToolDefinition
logger = get_logger(__name__)

View File

@@ -6,11 +6,12 @@
from __future__ import annotations
import inspect
from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
from typing import Any, Callable, Coroutine, get_type_hints
from typing import Any, get_type_hints
from minenasai.core import get_logger
from minenasai.agent.permissions import DangerLevel, ToolPermission, get_permission_manager
from minenasai.core import get_logger
from minenasai.llm.base import ToolDefinition
logger = get_logger(__name__)

View File

@@ -1,9 +1,9 @@
"""内置工具集"""
from minenasai.agent.tools.basic import (
read_file_tool,
list_directory_tool,
python_eval_tool,
read_file_tool,
)
__all__ = [

View File

@@ -5,8 +5,6 @@
from __future__ import annotations
import os
from pathlib import Path
from typing import Any
from minenasai.core import get_logger

View File

@@ -1,6 +1,42 @@
"""核心模块
"""核心模块 - 提供系统基础设施
提供配置管理、日志系统、数据库、监控、缓存等基础功能
本模块包含 MineNASAI 的核心基础功能
配置管理 (config):
- Settings: 全局配置数据类
- get_settings(): 获取配置单例
- load_config(): 从文件加载配置
日志系统 (logging):
- setup_logging(): 初始化日志系统
- get_logger(): 获取模块日志记录器
- AuditLogger: 审计日志记录器
监控系统 (monitoring):
- SystemMetrics: 系统指标收集
- HealthChecker: 健康检查器
- setup_monitoring(): 初始化监控
缓存系统 (cache):
- MemoryCache: 内存缓存,支持 TTL 和 LRU 淘汰
- RateLimiter: 令牌桶限流器
数据库 (database):
- Database: SQLite 异步数据库封装
- 支持会话、消息、任务的持久化
使用示例:
>>> from minenasai.core import get_settings, setup_logging, get_logger
>>>
>>> # 初始化配置和日志
>>> settings = get_settings()
>>> setup_logging(settings)
>>> logger = get_logger(__name__)
>>>
>>> # 使用缓存
>>> from minenasai.core import get_response_cache
>>> cache = get_response_cache()
>>> await cache.set("key", "value", ttl=300)
"""
from minenasai.core.cache import (

View File

@@ -6,6 +6,7 @@
from __future__ import annotations
import asyncio
import contextlib
import hashlib
import time
from dataclasses import dataclass, field
@@ -71,10 +72,8 @@ class MemoryCache(Generic[T]):
"""停止后台清理任务"""
if self._cleanup_task:
self._cleanup_task.cancel()
try:
with contextlib.suppress(asyncio.CancelledError):
await self._cleanup_task
except asyncio.CancelledError:
pass
self._cleanup_task = None
async def _cleanup_loop(self) -> None:

View File

@@ -296,7 +296,25 @@ _settings: Settings | None = None
def get_settings() -> Settings:
"""获取全局配置"""
"""获取全局配置
返回全局唯一的配置实例。首次调用时从配置文件加载,
后续调用返回缓存的实例。
Returns:
Settings: 全局配置对象
Example:
>>> settings = get_settings()
>>> print(settings.app.name)
MineNASAI
>>> print(settings.llm.default_model)
claude-sonnet-4-20250514
Note:
配置加载优先级:环境变量 > 配置文件 > 默认值
使用 reset_settings() 可重置配置(仅用于测试)
"""
global _settings
if _settings is None:
_settings = load_config()
@@ -304,6 +322,13 @@ def get_settings() -> Settings:
def reset_settings() -> None:
"""重置全局配置(用于测试)"""
"""重置全局配置单例
清除缓存的配置实例,下次调用 get_settings() 时将重新加载配置。
主要用于测试场景,生产环境慎用。
Warning:
此函数会清除所有已加载的配置,可能影响正在运行的组件。
"""
global _settings
_settings = None

View File

@@ -11,10 +11,11 @@ from __future__ import annotations
import asyncio
import time
import traceback
from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Coroutine
from typing import Any
from fastapi import FastAPI, Request, Response
from fastapi.responses import JSONResponse
@@ -161,7 +162,7 @@ class HealthChecker:
result.latency_ms = (time.time() - start_time) * 1000
self._results[name] = result
return result
except asyncio.TimeoutError:
except TimeoutError:
result = ComponentHealth(
name=name,
status=HealthStatus.UNHEALTHY,

View File

@@ -7,8 +7,9 @@ from __future__ import annotations
import json
import time
from collections.abc import Iterator
from pathlib import Path
from typing import Any, Iterator
from typing import Any
from minenasai.core.config import expand_path

View File

@@ -1,6 +1,51 @@
"""Gateway 模块
"""Gateway 模块 - 消息网关和渠道接入
提供 WebSocket 服务、通讯渠道接入、智能路由
本模块提供消息网关服务,支持多渠道消息接入:
WebSocket 服务 (server):
- 提供实时双向通信
- 支持多客户端连接管理
- 心跳检测和自动重连
渠道接入 (channels):
- FeishuChannel: 飞书机器人接入
- WeWorkChannel: 企业微信机器人接入
- 统一的消息格式转换
智能路由 (router):
- SmartRouter: 智能任务路由器
- 自动评估消息复杂度
- 根据复杂度选择处理策略
协议定义 (protocol):
- ChatMessage: 聊天消息模型
- MessageType: 消息类型枚举
- ChannelType: 渠道类型枚举
使用示例:
>>> from minenasai.gateway.router import SmartRouter
>>> from minenasai.gateway.protocol.schema import ChatMessage
>>>
>>> # 创建路由器
>>> router = SmartRouter()
>>>
>>> # 评估任务复杂度
>>> message = ChatMessage(content="帮我写一个 Python 爬虫")
>>> complexity = router.route(message)
"""
__all__ = []
from minenasai.gateway.protocol.schema import (
ChannelType,
ChatMessage,
MessageType,
TaskComplexity,
)
from minenasai.gateway.router import SmartRouter
__all__ = [
"ChatMessage",
"MessageType",
"ChannelType",
"TaskComplexity",
"SmartRouter",
]

View File

@@ -5,8 +5,6 @@
from __future__ import annotations
import hashlib
import hmac
import time
from typing import Any
@@ -33,8 +31,12 @@ class FeishuChannel(BaseChannel):
self._tenant_access_token: str | None = None
self._token_expires: float = 0
async def verify_signature(self, request: Any) -> bool:
"""验证飞书签名"""
async def verify_signature(self, _request: Any) -> bool:
"""验证飞书签名
Args:
_request: HTTP 请求对象(待实现时使用)
"""
# TODO: 实现签名验证
return True
@@ -142,10 +144,10 @@ class FeishuChannel(BaseChannel):
# 判断是用户还是群聊
chat_id = kwargs.get("chat_id")
if chat_id:
url = f"https://open.feishu.cn/open-apis/im/v1/messages?receive_id_type=chat_id"
url = "https://open.feishu.cn/open-apis/im/v1/messages?receive_id_type=chat_id"
receive_id = chat_id
else:
url = f"https://open.feishu.cn/open-apis/im/v1/messages?receive_id_type=user_id"
url = "https://open.feishu.cn/open-apis/im/v1/messages?receive_id_type=user_id"
receive_id = peer_id
import json

View File

@@ -5,7 +5,6 @@
from __future__ import annotations
import hashlib
import time
from typing import Any
@@ -33,10 +32,13 @@ class WeworkChannel(BaseChannel):
self._access_token: str | None = None
self._token_expires: float = 0
async def verify_signature(self, request: Any) -> bool:
async def verify_signature(self, _request: Any) -> bool:
"""验证企业微信签名
企业微信使用 SHA1 签名验证
企业微信使用 SHA1 签名验证
Args:
_request: HTTP 请求对象(待实现时使用)
"""
# TODO: 实现签名验证
# signature = sorted([self.token, timestamp, nonce, echostr])
@@ -105,9 +107,16 @@ class WeworkChannel(BaseChannel):
peer_id: str,
content: str,
message_type: str = "text",
**kwargs: Any,
**_kwargs: Any, # 预留:将来支持更多消息参数
) -> bool:
"""发送企业微信消息"""
"""发送企业微信消息
Args:
peer_id: 接收者用户 ID
content: 消息内容
message_type: 消息类型
**_kwargs: 预留参数,将来支持更多消息选项
"""
access_token = await self._get_access_token()
url = f"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={access_token}"

View File

@@ -5,14 +5,17 @@
from __future__ import annotations
from enum import Enum
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
class MessageType(str, Enum):
"""消息类型"""
class MessageType(StrEnum):
"""消息类型
定义 WebSocket 通信中的消息类型。
"""
# 客户端 -> 服务器
CHAT = "chat" # 普通聊天消息
@@ -32,8 +35,11 @@ class MessageType(str, Enum):
CONFIRM = "confirm" # 确认请求/响应
class ChannelType(str, Enum):
"""渠道类型"""
class ChannelType(StrEnum):
"""渠道类型
定义消息来源的渠道类型。
"""
WEWORK = "wework"
FEISHU = "feishu"
@@ -41,8 +47,11 @@ class ChannelType(str, Enum):
WEBSOCKET = "websocket"
class TaskComplexity(str, Enum):
"""任务复杂度"""
class TaskComplexity(StrEnum):
"""任务复杂度
用于智能路由,决定任务的处理方式。
"""
SIMPLE = "simple" # 简单查询,直接回复
MEDIUM = "medium" # 中等任务,需要工具

View File

@@ -6,8 +6,9 @@
from __future__ import annotations
import asyncio
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator
from typing import Any
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
@@ -152,10 +153,16 @@ async def list_sessions(agent_id: str | None = None) -> dict[str, Any]:
async def handle_chat_message(
websocket: WebSocket,
client_id: str,
_client_id: str, # 预留:将来用于客户端标识和会话管理
message: ChatMessage,
) -> None:
"""处理聊天消息"""
"""处理聊天消息
Args:
websocket: WebSocket 连接
_client_id: 客户端 ID预留将来用于会话管理
message: 聊天消息
"""
router = get_router()
# 发送思考状态

View File

@@ -1,6 +1,44 @@
"""LLM 多模型客户端模块
"""LLM 多模型客户端模块 - 统一的 AI API 接口
支持多个 AI API 提供商的统一接口
本模块提供统一的 LLM 客户端接口,支持多个 AI 提供商:
支持的提供商:
- OpenAI (GPT-4, GPT-3.5)
- Anthropic (Claude 3)
- Google (Gemini)
- DeepSeek
- Moonshot (月之暗面)
- MiniMax
- Zhipu (智谱)
核心类:
- BaseLLMClient: LLM 客户端基类
- LLMManager: 多模型管理器,支持负载均衡和故障转移
- Message: 统一消息格式
- LLMResponse: 统一响应格式
- ToolCall: 工具调用定义
特性:
- 统一的 API 接口
- 自动重试和错误处理
- 流式响应支持
- 工具调用 (Function Calling) 支持
- 响应缓存
使用示例:
>>> from minenasai.llm import get_llm_manager, Message
>>>
>>> # 获取 LLM 管理器
>>> manager = get_llm_manager()
>>>
>>> # 发送消息
>>> messages = [Message(role="user", content="你好")]
>>> response = await manager.chat(messages)
>>> print(response.content)
>>>
>>> # 流式响应
>>> async for chunk in manager.stream(messages):
... print(chunk.content, end="")
"""
from minenasai.llm.base import BaseLLMClient, LLMResponse, Message, ToolCall

View File

@@ -6,13 +6,17 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, AsyncIterator
from enum import StrEnum
from typing import Any
class Provider(str, Enum):
"""LLM 提供商"""
class Provider(StrEnum):
"""LLM 提供商
支持的 AI API 提供商枚举。
"""
ANTHROPIC = "anthropic" # Claude
OPENAI = "openai" # GPT

View File

@@ -1,12 +1,12 @@
"""LLM 客户端实现"""
from minenasai.llm.clients.anthropic import AnthropicClient
from minenasai.llm.clients.openai_compat import OpenAICompatClient
from minenasai.llm.clients.deepseek import DeepSeekClient
from minenasai.llm.clients.zhipu import ZhipuClient
from minenasai.llm.clients.gemini import GeminiClient
from minenasai.llm.clients.minimax import MiniMaxClient
from minenasai.llm.clients.moonshot import MoonshotClient
from minenasai.llm.clients.gemini import GeminiClient
from minenasai.llm.clients.openai_compat import OpenAICompatClient
from minenasai.llm.clients.zhipu import ZhipuClient
__all__ = [
"AnthropicClient",

View File

@@ -3,7 +3,8 @@
from __future__ import annotations
import json
from typing import Any, AsyncIterator
from collections.abc import AsyncIterator
from typing import Any
import httpx
@@ -74,7 +75,7 @@ class AnthropicClient(BaseLLMClient):
max_tokens: int = 4096,
temperature: float = 0.7,
tools: list[ToolDefinition] | None = None,
**kwargs: Any,
**_kwargs: Any, # 预留:基类接口扩展参数
) -> LLMResponse:
"""发送聊天请求"""
client = await self._get_client()
@@ -142,7 +143,7 @@ class AnthropicClient(BaseLLMClient):
max_tokens: int = 4096,
temperature: float = 0.7,
tools: list[ToolDefinition] | None = None,
**kwargs: Any,
**_kwargs: Any, # 预留:基类接口扩展参数
) -> AsyncIterator[StreamChunk]:
"""流式聊天"""
client = await self._get_client()

View File

@@ -6,7 +6,8 @@ Gemini API 使用独特的接口格式
from __future__ import annotations
import json
from typing import Any, AsyncIterator
from collections.abc import AsyncIterator
from typing import Any
import httpx
@@ -76,7 +77,7 @@ class GeminiClient(BaseLLMClient):
max_tokens: int = 4096,
temperature: float = 0.7,
tools: list[ToolDefinition] | None = None,
**kwargs: Any,
**_kwargs: Any, # 预留:基类接口扩展参数
) -> LLMResponse:
"""发送聊天请求"""
client = await self._get_client()
@@ -164,10 +165,10 @@ class GeminiClient(BaseLLMClient):
system: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
tools: list[ToolDefinition] | None = None,
**kwargs: Any,
_tools: list[ToolDefinition] | None = None, # 暂不支持流式工具调用
**_kwargs: Any, # 预留:基类接口扩展参数
) -> AsyncIterator[StreamChunk]:
"""流式聊天"""
"""流式聊天(暂不支持工具调用)"""
client = await self._get_client()
model = model or self.default_model

View File

@@ -5,8 +5,6 @@ MiniMax 使用 OpenAI 兼容接口
from __future__ import annotations
from typing import Any
import httpx
from minenasai.llm.base import Provider

View File

@@ -6,7 +6,8 @@
from __future__ import annotations
import json
from typing import Any, AsyncIterator
from collections.abc import AsyncIterator
from typing import Any
import httpx
@@ -87,7 +88,7 @@ class OpenAICompatClient(BaseLLMClient):
max_tokens: int = 4096,
temperature: float = 0.7,
tools: list[ToolDefinition] | None = None,
**kwargs: Any,
**_kwargs: Any, # 预留:基类接口扩展参数
) -> LLMResponse:
"""发送聊天请求"""
client = await self._get_client()
@@ -154,7 +155,7 @@ class OpenAICompatClient(BaseLLMClient):
max_tokens: int = 4096,
temperature: float = 0.7,
tools: list[ToolDefinition] | None = None,
**kwargs: Any,
**_kwargs: Any, # 预留:基类接口扩展参数
) -> AsyncIterator[StreamChunk]:
"""流式聊天"""
client = await self._get_client()

View File

@@ -5,7 +5,8 @@
from __future__ import annotations
from typing import Any, AsyncIterator
from collections.abc import AsyncIterator
from typing import Any
from minenasai.core import get_logger, get_settings
from minenasai.llm.base import (
@@ -142,10 +143,10 @@ class LLMManager:
client = self._create_client(provider)
if client:
self._clients[provider] = client
logger.info(f"LLM 客户端已初始化", provider=provider.display_name)
logger.info("LLM 客户端已初始化", provider=provider.display_name)
self._initialized = True
logger.info(f"LLM Manager 初始化完成", providers=list(self._clients.keys()))
logger.info("LLM Manager 初始化完成", providers=list(self._clients.keys()))
def get_client(self, provider: Provider) -> BaseLLMClient | None:
"""获取指定提供商的客户端"""
@@ -334,7 +335,30 @@ _llm_manager: LLMManager | None = None
def get_llm_manager() -> LLMManager:
"""获取全局 LLM 管理器"""
"""获取全局 LLM 管理器单例
返回全局唯一的 LLM 管理器实例。管理器会根据配置自动初始化
可用的 LLM 客户端。
Returns:
LLMManager: 全局 LLM 管理器
Example:
>>> manager = get_llm_manager()
>>> # 发送消息
>>> response = await manager.chat([
... Message(role="user", content="你好")
... ])
>>> print(response.content)
>>>
>>> # 流式响应
>>> async for chunk in manager.stream(messages):
... print(chunk.content, end="")
Note:
首次调用时会自动初始化所有配置了 API Key 的客户端。
支持的提供商取决于环境变量或配置文件中的 API Key 设置。
"""
global _llm_manager
if _llm_manager is None:
_llm_manager = LLMManager()

View File

@@ -1,6 +1,43 @@
"""定时任务调度模块
"""定时任务调度模块 - Cron 任务调度
提供 Cron 任务调度功能
本模块提供类似 Linux cron 的定时任务调度功能
核心类:
- CronScheduler: 调度器,管理所有定时任务
- CronJob: 定时任务定义
Cron 表达式:
支持标准 5 字段 cron 表达式:分 时 日 月 周
- * : 任意值
- */n : 每隔 n
- n-m : 范围
- n,m : 列表
预设表达式:
- @hourly : 每小时
- @daily : 每天
- @weekly : 每周
- @monthly : 每月
使用示例:
>>> from minenasai.scheduler import get_scheduler
>>>
>>> # 获取调度器
>>> scheduler = get_scheduler()
>>>
>>> # 添加定时任务
>>> async def my_task():
... print("任务执行中...")
>>>
>>> scheduler.add_job(
... job_id="my_job",
... cron_expr="0 9 * * *", # 每天 9:00
... func=my_task,
... description="每日任务"
... )
>>>
>>> # 启动调度器
>>> await scheduler.start()
"""
from minenasai.scheduler.cron import CronJob, CronScheduler, get_scheduler

View File

@@ -6,21 +6,24 @@
from __future__ import annotations
import asyncio
import re
import contextlib
import time
from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Callable, Coroutine
from enum import StrEnum
from typing import Any
from minenasai.core import get_logger
from minenasai.core.database import get_database
logger = get_logger(__name__)
class JobStatus(str, Enum):
"""任务状态"""
class JobStatus(StrEnum):
"""任务状态
定时任务的执行状态。
"""
PENDING = "pending"
RUNNING = "running"
@@ -283,10 +286,8 @@ class CronScheduler:
self._running = False
if self._task:
self._task.cancel()
try:
with contextlib.suppress(asyncio.CancelledError):
await self._task
except asyncio.CancelledError:
pass
logger.info("定时调度器已停止")
async def _run_loop(self) -> None:

View File

@@ -1,6 +1,36 @@
"""Web TUI 模块
"""Web TUI 模块 - Web 管理界面
提供 Web 终端界面、SSH 连接管理
本模块提供基于 Web 终端管理界面:
认证管理 (auth):
- AuthManager: 认证管理器
- AuthToken: 认证令牌
- 支持 JWT 令牌认证
- 令牌刷新和撤销
SSH 管理 (ssh_manager):
- SSHManager: SSH 连接管理器
- SSHSession: SSH 会话封装
- 支持多会话管理
- WebSocket 实时终端
Web 服务 (server):
- FastAPI 应用
- 静态文件服务
- WebSocket 终端
- RESTful API
使用示例:
>>> from minenasai.webtui import get_auth_manager
>>>
>>> # 获取认证管理器
>>> auth = get_auth_manager()
>>>
>>> # 生成令牌
>>> token = auth.generate_token(user_id="admin")
>>>
>>> # 验证令牌
>>> user_info = auth.verify_token(token.token)
"""
from minenasai.webtui.auth import AuthManager, AuthToken, get_auth_manager

View File

@@ -7,13 +7,14 @@ from __future__ import annotations
import asyncio
import uuid
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, AsyncGenerator
from typing import Any
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from minenasai.core import get_logger, get_settings, setup_logging
@@ -95,8 +96,12 @@ manager = ConnectionManager()
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""应用生命周期管理"""
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
"""应用生命周期管理
Args:
_app: FastAPI 应用实例lifespan 标准签名要求)
"""
settings = get_settings()
setup_logging(settings.logging)
logger.info("Web TUI 服务启动", port=settings.webtui.port)

View File

@@ -6,10 +6,11 @@
from __future__ import annotations
import asyncio
import contextlib
import os
import time
from pathlib import Path
from typing import Any, Callable
from collections.abc import Callable
from typing import Any
import paramiko
@@ -128,10 +129,8 @@ class SSHSession:
if self._read_task:
self._read_task.cancel()
try:
with contextlib.suppress(asyncio.CancelledError):
await self._read_task
except asyncio.CancelledError:
pass
if self.channel:
self.channel.close()

View File

@@ -4,8 +4,6 @@ from __future__ import annotations
import json
import pytest
from minenasai.core import Settings, get_settings, load_config, reset_settings
from minenasai.core.config import expand_path

345
tests/test_database.py Normal file
View File

@@ -0,0 +1,345 @@
"""数据库模块测试
测试 SQLite 异步数据库操作
"""
import pytest
from minenasai.core.database import Database
class TestDatabase:
"""数据库基本操作测试"""
@pytest.fixture
async def db(self, tmp_path):
"""创建临时数据库"""
db_path = tmp_path / "test.db"
db = Database(db_path)
await db.connect()
yield db
await db.close()
@pytest.mark.asyncio
async def test_connect_and_close(self, tmp_path):
"""测试数据库连接和关闭"""
db_path = tmp_path / "test.db"
db = Database(db_path)
# 连接前访问 conn 应该报错
with pytest.raises(RuntimeError, match="数据库未连接"):
_ = db.conn
# 连接
await db.connect()
assert db.conn is not None
# 关闭
await db.close()
assert db._conn is None
class TestAgentOperations:
"""Agent 操作测试"""
@pytest.fixture
async def db(self, tmp_path):
"""创建临时数据库"""
db_path = tmp_path / "test.db"
db = Database(db_path)
await db.connect()
yield db
await db.close()
@pytest.mark.asyncio
async def test_create_agent(self, db):
"""测试创建 Agent"""
agent = await db.create_agent(
agent_id="test-agent-1",
name="Test Agent",
workspace_path="/tmp/workspace",
)
assert agent["id"] == "test-agent-1"
assert agent["name"] == "Test Agent"
assert agent["workspace_path"] == "/tmp/workspace"
assert agent["model"] == "claude-sonnet-4-20250514"
assert agent["sandbox_mode"] == "workspace"
assert "created_at" in agent
assert "updated_at" in agent
@pytest.mark.asyncio
async def test_create_agent_custom_model(self, db):
"""测试使用自定义模型创建 Agent"""
agent = await db.create_agent(
agent_id="test-agent-2",
name="Custom Agent",
workspace_path="/tmp/workspace",
model="gpt-4",
sandbox_mode="strict",
)
assert agent["model"] == "gpt-4"
assert agent["sandbox_mode"] == "strict"
@pytest.mark.asyncio
async def test_get_agent(self, db):
"""测试获取 Agent"""
# 创建 Agent
await db.create_agent(
agent_id="test-agent-3",
name="Get Test Agent",
workspace_path="/tmp/workspace",
)
# 获取存在的 Agent
agent = await db.get_agent("test-agent-3")
assert agent is not None
assert agent["name"] == "Get Test Agent"
# 获取不存在的 Agent
agent = await db.get_agent("nonexistent")
assert agent is None
@pytest.mark.asyncio
async def test_list_agents(self, db):
"""测试列出所有 Agent"""
# 初始为空
agents = await db.list_agents()
assert len(agents) == 0
# 创建多个 Agent
await db.create_agent("agent-1", "Agent 1", "/tmp/1")
await db.create_agent("agent-2", "Agent 2", "/tmp/2")
await db.create_agent("agent-3", "Agent 3", "/tmp/3")
# 列出所有
agents = await db.list_agents()
assert len(agents) == 3
class TestSessionOperations:
"""Session 操作测试"""
@pytest.fixture
async def db(self, tmp_path):
"""创建临时数据库并添加测试 Agent"""
db_path = tmp_path / "test.db"
db = Database(db_path)
await db.connect()
await db.create_agent("test-agent", "Test Agent", "/tmp/workspace")
yield db
await db.close()
@pytest.mark.asyncio
async def test_create_session(self, db):
"""测试创建会话"""
session = await db.create_session(
agent_id="test-agent",
channel="websocket",
peer_id="user-123",
)
assert session["agent_id"] == "test-agent"
assert session["channel"] == "websocket"
assert session["peer_id"] == "user-123"
assert session["status"] == "active"
assert "session_key" in session
@pytest.mark.asyncio
async def test_create_session_with_metadata(self, db):
"""测试创建带元数据的会话"""
metadata = {"client": "web", "version": "1.0"}
session = await db.create_session(
agent_id="test-agent",
channel="websocket",
metadata=metadata,
)
assert session["metadata"] == metadata
@pytest.mark.asyncio
async def test_get_session(self, db):
"""测试获取会话"""
# 创建会话
created = await db.create_session(
agent_id="test-agent",
channel="websocket",
metadata={"test": "data"},
)
# 获取会话
session = await db.get_session(created["id"])
assert session is not None
assert session["id"] == created["id"]
assert session["metadata"] == {"test": "data"}
# 获取不存在的会话
session = await db.get_session("nonexistent-id")
assert session is None
@pytest.mark.asyncio
async def test_update_session_status(self, db):
"""测试更新会话状态"""
# 创建会话
session = await db.create_session(
agent_id="test-agent",
channel="websocket",
)
assert session["status"] == "active"
# 更新状态
await db.update_session_status(session["id"], "closed")
# 验证
updated = await db.get_session(session["id"])
assert updated["status"] == "closed"
@pytest.mark.asyncio
async def test_list_active_sessions(self, db):
"""测试列出活跃会话"""
# 创建多个会话
session1 = await db.create_session("test-agent", "websocket")
session2 = await db.create_session("test-agent", "feishu")
await db.create_session("test-agent", "wework")
# 关闭一个会话
await db.update_session_status(session2["id"], "closed")
# 列出所有活跃会话
active = await db.list_active_sessions()
assert len(active) == 2
# 按 agent_id 过滤
active = await db.list_active_sessions("test-agent")
assert len(active) == 2
class TestMessageOperations:
"""Message 操作测试"""
@pytest.fixture
async def db(self, tmp_path):
"""创建临时数据库并添加测试会话"""
db_path = tmp_path / "test.db"
db = Database(db_path)
await db.connect()
await db.create_agent("test-agent", "Test Agent", "/tmp/workspace")
session = await db.create_session("test-agent", "websocket")
db._test_session_id = session["id"]
yield db
await db.close()
@pytest.mark.asyncio
async def test_add_message(self, db):
"""测试添加消息"""
session_id = db._test_session_id
message = await db.add_message(
session_id=session_id,
role="user",
content="Hello, world!",
)
assert message["session_id"] == session_id
assert message["role"] == "user"
assert message["content"] == "Hello, world!"
assert message["tokens_used"] == 0
@pytest.mark.asyncio
async def test_add_message_with_tool_calls(self, db):
"""测试添加带工具调用的消息"""
session_id = db._test_session_id
tool_calls = [
{"id": "call-1", "name": "read_file", "arguments": {"path": "/tmp/test"}}
]
message = await db.add_message(
session_id=session_id,
role="assistant",
content=None,
tool_calls=tool_calls,
tokens_used=100,
)
assert message["tool_calls"] == tool_calls
assert message["tokens_used"] == 100
@pytest.mark.asyncio
async def test_get_messages(self, db):
"""测试获取会话消息"""
session_id = db._test_session_id
# 添加多条消息
await db.add_message(session_id, "user", "Message 1")
await db.add_message(session_id, "assistant", "Response 1")
await db.add_message(session_id, "user", "Message 2")
await db.add_message(session_id, "assistant", "Response 2")
# 获取所有消息
messages = await db.get_messages(session_id)
assert len(messages) == 4
# 验证角色交替
roles = [m["role"] for m in messages]
assert roles.count("user") == 2
assert roles.count("assistant") == 2
@pytest.mark.asyncio
async def test_get_messages_with_limit(self, db):
"""测试获取消息数量限制"""
session_id = db._test_session_id
# 添加多条消息
for i in range(10):
await db.add_message(session_id, "user", f"Message {i}")
# 限制返回数量
messages = await db.get_messages(session_id, limit=5)
assert len(messages) == 5
class TestAuditLog:
"""审计日志测试"""
@pytest.fixture
async def db(self, tmp_path):
"""创建临时数据库"""
db_path = tmp_path / "test.db"
db = Database(db_path)
await db.connect()
yield db
await db.close()
@pytest.mark.asyncio
async def test_add_audit_log(self, db):
"""测试添加审计日志"""
# 添加日志不应该报错
await db.add_audit_log(
agent_id="test-agent",
tool_name="read_file",
danger_level="safe",
params={"path": "/tmp/test.txt"},
result="success",
duration_ms=50,
)
@pytest.mark.asyncio
async def test_add_audit_log_minimal(self, db):
"""测试添加最小审计日志"""
await db.add_audit_log(
agent_id=None,
tool_name="python_eval",
danger_level="low",
)
class TestGlobalDatabase:
"""全局数据库实例测试"""
@pytest.mark.asyncio
async def test_import_functions(self):
"""测试导入全局函数"""
from minenasai.core.database import close_database, get_database
assert callable(get_database)
assert callable(close_database)

View File

@@ -142,3 +142,159 @@ class TestRouterEdgeCases:
result = self.router.evaluate("查看 /tmp/test.txt 文件内容")
assert result["complexity"] in [TaskComplexity.SIMPLE, TaskComplexity.MEDIUM]
class TestConnectionManager:
"""WebSocket 连接管理器测试"""
def test_import_manager(self):
"""测试导入连接管理器"""
from minenasai.gateway.server import ConnectionManager
manager = ConnectionManager()
assert manager.active_connections == {}
@pytest.mark.asyncio
async def test_connect_and_disconnect(self):
"""测试连接和断开"""
from unittest.mock import AsyncMock, MagicMock
from minenasai.gateway.server import ConnectionManager
manager = ConnectionManager()
# Mock WebSocket
mock_ws = AsyncMock()
mock_ws.accept = AsyncMock()
# 连接
await manager.connect(mock_ws, "client-1")
assert "client-1" in manager.active_connections
mock_ws.accept.assert_called_once()
# 断开
manager.disconnect("client-1")
assert "client-1" not in manager.active_connections
@pytest.mark.asyncio
async def test_disconnect_nonexistent(self):
"""测试断开不存在的连接"""
from minenasai.gateway.server import ConnectionManager
manager = ConnectionManager()
# 不应该抛出异常
manager.disconnect("nonexistent")
@pytest.mark.asyncio
async def test_send_message(self):
"""测试发送消息"""
from unittest.mock import AsyncMock
from minenasai.gateway.server import ConnectionManager
manager = ConnectionManager()
# Mock WebSocket
mock_ws = AsyncMock()
mock_ws.accept = AsyncMock()
mock_ws.send_json = AsyncMock()
# 连接
await manager.connect(mock_ws, "client-1")
# 发送消息
await manager.send_message("client-1", {"type": "test"})
mock_ws.send_json.assert_called_once_with({"type": "test"})
@pytest.mark.asyncio
async def test_send_message_to_nonexistent(self):
"""测试发送消息给不存在的客户端"""
from minenasai.gateway.server import ConnectionManager
manager = ConnectionManager()
# 不应该抛出异常
await manager.send_message("nonexistent", {"type": "test"})
@pytest.mark.asyncio
async def test_broadcast(self):
"""测试广播消息"""
from unittest.mock import AsyncMock
from minenasai.gateway.server import ConnectionManager
manager = ConnectionManager()
# Mock 多个 WebSocket
mock_ws1 = AsyncMock()
mock_ws1.accept = AsyncMock()
mock_ws1.send_json = AsyncMock()
mock_ws2 = AsyncMock()
mock_ws2.accept = AsyncMock()
mock_ws2.send_json = AsyncMock()
# 连接
await manager.connect(mock_ws1, "client-1")
await manager.connect(mock_ws2, "client-2")
# 广播
await manager.broadcast({"type": "broadcast"})
mock_ws1.send_json.assert_called_once_with({"type": "broadcast"})
mock_ws2.send_json.assert_called_once_with({"type": "broadcast"})
class TestGatewayServer:
"""Gateway 服务器测试"""
def test_import_app(self):
"""测试导入应用"""
from minenasai.gateway.server import app
assert app is not None
assert app.title == "MineNASAI Gateway"
def test_import_endpoints(self):
"""测试导入端点函数"""
from minenasai.gateway.server import list_agents, list_sessions, root
assert callable(root)
assert callable(list_agents)
assert callable(list_sessions)
class TestMessageTypes:
"""消息类型测试"""
def test_status_message(self):
"""测试状态消息"""
from minenasai.gateway.protocol import StatusMessage
msg = StatusMessage(status="thinking", message="处理中...")
assert msg.type == MessageType.STATUS
assert msg.status == "thinking"
assert msg.message == "处理中..."
def test_response_message(self):
"""测试响应消息"""
from minenasai.gateway.protocol import ResponseMessage
msg = ResponseMessage(content="Hello!", in_reply_to="msg-123")
assert msg.type == MessageType.RESPONSE
assert msg.content == "Hello!"
assert msg.in_reply_to == "msg-123"
def test_error_message(self):
"""测试错误消息"""
from minenasai.gateway.protocol import ErrorMessage
msg = ErrorMessage(message="Something went wrong", code="ERR_001")
assert msg.type == MessageType.ERROR
assert msg.message == "Something went wrong"
assert msg.code == "ERR_001"
def test_pong_message(self):
"""测试心跳响应消息"""
from minenasai.gateway.protocol import PongMessage
msg = PongMessage()
assert msg.type == MessageType.PONG

View File

@@ -2,9 +2,18 @@
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from minenasai.llm.base import Message, Provider, ToolCall, ToolDefinition
from minenasai.llm.base import (
LLMResponse,
Message,
Provider,
StreamChunk,
ToolCall,
ToolDefinition,
)
class TestProvider:
@@ -133,3 +142,208 @@ class TestLLMManager:
providers = manager.get_available_providers()
# 可能为空,取决于环境变量
assert isinstance(providers, list)
class TestOpenAICompatClientMock:
"""OpenAI 兼容客户端 Mock 测试"""
@pytest.fixture
def client(self):
"""创建测试客户端"""
from minenasai.llm.clients import OpenAICompatClient
return OpenAICompatClient(api_key="test-key", base_url="https://api.test.com/v1")
@pytest.mark.asyncio
async def test_chat_mock(self, client):
"""测试聊天功能Mock"""
# Mock HTTP 响应
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": "gpt-4o",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello! How can I help you?",
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30,
},
}
with patch.object(client, "_get_client") as mock_get_client:
mock_http_client = AsyncMock()
mock_http_client.post = AsyncMock(return_value=mock_response)
mock_get_client.return_value = mock_http_client
messages = [Message(role="user", content="Hello")]
response = await client.chat(messages)
assert isinstance(response, LLMResponse)
assert response.content == "Hello! How can I help you?"
assert response.model == "gpt-4o"
assert response.finish_reason == "stop"
assert response.usage["total_tokens"] == 30
@pytest.mark.asyncio
async def test_chat_with_tools_mock(self, client):
"""测试带工具调用的聊天Mock"""
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"id": "chatcmpl-456",
"model": "gpt-4o",
"choices": [
{
"message": {
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call_123",
"type": "function",
"function": {
"name": "read_file",
"arguments": '{"path": "/tmp/test.txt"}',
},
}
],
},
"finish_reason": "tool_calls",
}
],
"usage": {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40},
}
with patch.object(client, "_get_client") as mock_get_client:
mock_http_client = AsyncMock()
mock_http_client.post = AsyncMock(return_value=mock_response)
mock_get_client.return_value = mock_http_client
messages = [Message(role="user", content="Read the test file")]
tools = [
ToolDefinition(
name="read_file",
description="Read a file",
parameters={"type": "object", "properties": {"path": {"type": "string"}}},
)
]
response = await client.chat(messages, tools=tools)
assert response.tool_calls is not None
assert len(response.tool_calls) == 1
assert response.tool_calls[0].name == "read_file"
assert response.tool_calls[0].arguments == {"path": "/tmp/test.txt"}
@pytest.mark.asyncio
async def test_close_client(self, client):
"""测试关闭客户端"""
# 创建 mock 客户端
mock_http_client = AsyncMock()
client._client = mock_http_client
await client.close()
mock_http_client.aclose.assert_called_once()
assert client._client is None
class TestAnthropicClientMock:
"""Anthropic 客户端 Mock 测试"""
@pytest.fixture
def client(self):
"""创建测试客户端"""
from minenasai.llm.clients import AnthropicClient
return AnthropicClient(api_key="test-key")
@pytest.mark.asyncio
async def test_chat_mock(self, client):
"""测试聊天功能Mock"""
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"id": "msg_123",
"type": "message",
"role": "assistant",
"content": [{"type": "text", "text": "Hello from Claude!"}],
"model": "claude-sonnet-4-20250514",
"stop_reason": "end_turn",
"usage": {"input_tokens": 10, "output_tokens": 15},
}
with patch.object(client, "_get_client") as mock_get_client:
mock_http_client = AsyncMock()
mock_http_client.post = AsyncMock(return_value=mock_response)
mock_get_client.return_value = mock_http_client
messages = [Message(role="user", content="Hello")]
response = await client.chat(messages)
assert isinstance(response, LLMResponse)
assert response.content == "Hello from Claude!"
assert response.provider == Provider.ANTHROPIC
class TestLLMResponse:
"""LLM 响应测试"""
def test_response_basic(self):
"""测试基本响应"""
response = LLMResponse(
content="Test response",
model="gpt-4o",
provider=Provider.OPENAI,
)
assert response.content == "Test response"
assert response.model == "gpt-4o"
assert response.provider == Provider.OPENAI
assert response.finish_reason == "stop"
def test_response_with_tool_calls(self):
"""测试带工具调用的响应"""
tool_calls = [
ToolCall(id="tc_1", name="read_file", arguments={"path": "/test"}),
ToolCall(id="tc_2", name="list_dir", arguments={"path": "/"}),
]
response = LLMResponse(
content="",
model="claude-sonnet-4-20250514",
provider=Provider.ANTHROPIC,
tool_calls=tool_calls,
finish_reason="tool_use",
)
assert len(response.tool_calls) == 2
assert response.finish_reason == "tool_use"
class TestStreamChunk:
"""流式响应块测试"""
def test_chunk_basic(self):
"""测试基本响应块"""
chunk = StreamChunk(content="Hello")
assert chunk.content == "Hello"
assert chunk.is_final is False
def test_chunk_final(self):
"""测试最终响应块"""
chunk = StreamChunk(
content="",
is_final=True,
usage={"prompt_tokens": 10, "completion_tokens": 20},
)
assert chunk.is_final is True
assert chunk.usage is not None

View File

@@ -30,7 +30,7 @@ class TestToolPermission:
name="test_tool",
danger_level=DangerLevel.SAFE,
)
assert perm.requires_confirmation is False
assert perm.rate_limit is None
@@ -45,7 +45,7 @@ class TestPermissionManager:
def test_get_default_permission(self):
"""测试获取默认权限"""
perm = self.manager.get_permission("read_file")
assert perm is not None
assert perm.danger_level == DangerLevel.SAFE
@@ -57,7 +57,7 @@ class TestPermissionManager:
description="自定义工具",
)
self.manager.register_tool(perm)
result = self.manager.get_permission("custom_tool")
assert result is not None
assert result.danger_level == DangerLevel.MEDIUM
@@ -65,13 +65,13 @@ class TestPermissionManager:
def test_check_permission_allowed(self):
"""测试权限检查 - 允许"""
allowed, reason = self.manager.check_permission("read_file")
assert allowed is True
def test_check_permission_unknown_tool(self):
"""测试权限检查 - 未知工具"""
allowed, reason = self.manager.check_permission("unknown_tool")
assert allowed is False
assert "未知工具" in reason
@@ -83,12 +83,12 @@ class TestPermissionManager:
denied_paths=["/etc/", "/root/"],
)
self.manager.register_tool(perm)
allowed, reason = self.manager.check_permission(
"restricted_read",
params={"path": "/etc/passwd"},
)
assert allowed is False
assert "禁止访问" in reason
@@ -96,7 +96,7 @@ class TestPermissionManager:
"""测试确认要求 - 按等级"""
# HIGH 级别需要确认
assert self.manager.requires_confirmation("delete_file") is True
# SAFE 级别不需要确认
assert self.manager.requires_confirmation("read_file") is False
@@ -108,7 +108,7 @@ class TestPermissionManager:
requires_confirmation=True,
)
self.manager.register_tool(perm)
assert self.manager.requires_confirmation("explicit_confirm") is True
@pytest.mark.asyncio
@@ -119,9 +119,9 @@ class TestPermissionManager:
tool_name="delete_file",
params={"path": "/test.txt"},
)
assert request.status == ConfirmationStatus.PENDING
# 批准
self.manager.approve_confirmation("req-1")
assert request.status == ConfirmationStatus.APPROVED
@@ -134,7 +134,7 @@ class TestPermissionManager:
tool_name="delete_file",
params={"path": "/test.txt"},
)
self.manager.deny_confirmation("req-2")
assert request.status == ConfirmationStatus.DENIED
@@ -146,7 +146,7 @@ class TestPermissionManager:
tool_name="test",
params={},
)
pending = self.manager.get_pending_confirmations()
assert len(pending) >= 1
@@ -157,48 +157,48 @@ class TestToolRegistry:
def test_import_registry(self):
"""测试导入注册中心"""
from minenasai.agent import ToolRegistry, get_tool_registry
registry = get_tool_registry()
assert isinstance(registry, ToolRegistry)
def test_register_builtin_tools(self):
"""测试注册内置工具"""
from minenasai.agent import get_tool_registry, register_builtin_tools
registry = get_tool_registry()
initial_count = len(registry.list_tools())
register_builtin_tools()
# 应该有更多工具
new_count = len(registry.list_tools())
assert new_count >= initial_count
def test_tool_decorator(self):
"""测试工具装饰器"""
from minenasai.agent import tool, get_tool_registry
from minenasai.agent import get_tool_registry, tool
@tool(name="decorated_tool", description="装饰器测试")
async def decorated_tool(param: str) -> dict:
return {"result": param}
registry = get_tool_registry()
tool_obj = registry.get("decorated_tool")
assert tool_obj is not None
assert tool_obj.description == "装饰器测试"
@pytest.mark.asyncio
async def test_execute_tool(self):
"""测试执行工具"""
from minenasai.agent import get_tool_registry, DangerLevel
from minenasai.agent import DangerLevel, get_tool_registry
registry = get_tool_registry()
# 注册测试工具
async def echo(message: str) -> dict:
return {"echo": message}
registry.register(
name="echo",
description="回显消息",
@@ -210,18 +210,18 @@ class TestToolRegistry:
},
danger_level=DangerLevel.SAFE,
)
result = await registry.execute("echo", {"message": "hello"})
assert result["success"] is True
assert result["result"]["echo"] == "hello"
def test_get_stats(self):
"""测试获取统计"""
from minenasai.agent import get_tool_registry
registry = get_tool_registry()
stats = registry.get_stats()
assert "total_tools" in stats
assert "categories" in stats

View File

@@ -15,7 +15,7 @@ class TestCronParser:
def test_parse_all_stars(self):
"""测试全星号表达式"""
result = CronParser.parse("* * * * *")
assert len(result["minute"]) == 60
assert len(result["hour"]) == 24
assert len(result["day"]) == 31
@@ -25,39 +25,39 @@ class TestCronParser:
def test_parse_specific_values(self):
"""测试具体值"""
result = CronParser.parse("30 8 * * *")
assert result["minute"] == {30}
assert result["hour"] == {8}
def test_parse_range(self):
"""测试范围"""
result = CronParser.parse("0 9-17 * * *")
assert result["hour"] == {9, 10, 11, 12, 13, 14, 15, 16, 17}
def test_parse_step(self):
"""测试步进"""
result = CronParser.parse("*/15 * * * *")
assert result["minute"] == {0, 15, 30, 45}
def test_parse_list(self):
"""测试列表"""
result = CronParser.parse("0 8,12,18 * * *")
assert result["hour"] == {8, 12, 18}
def test_parse_preset_daily(self):
"""测试预定义表达式 @daily"""
result = CronParser.parse("@daily")
assert result["minute"] == {0}
assert result["hour"] == {0}
def test_parse_preset_hourly(self):
"""测试预定义表达式 @hourly"""
result = CronParser.parse("@hourly")
assert result["minute"] == {0}
assert len(result["hour"]) == 24
@@ -73,7 +73,7 @@ class TestCronParser:
"0 * * * *",
after=datetime(2026, 1, 1, 10, 30)
)
assert next_run.minute == 0
assert next_run.hour == 11
@@ -89,7 +89,7 @@ class TestCronJob:
schedule="*/5 * * * *",
task="测试",
)
assert job.id == "test-job"
assert job.enabled is True
assert job.last_status == JobStatus.PENDING
@@ -113,7 +113,7 @@ class TestCronScheduler:
schedule="*/5 * * * *",
callback=task,
)
assert job.id == "test-1"
assert job.next_run is not None
@@ -123,7 +123,7 @@ class TestCronScheduler:
pass
self.scheduler.add_job("test-1", "测试", "* * * * *", task)
assert self.scheduler.remove_job("test-1") is True
assert self.scheduler.get_job("test-1") is None
@@ -133,12 +133,12 @@ class TestCronScheduler:
pass
self.scheduler.add_job("test-1", "测试", "* * * * *", task)
assert self.scheduler.disable_job("test-1") is True
job = self.scheduler.get_job("test-1")
assert job.enabled is False
assert job.last_status == JobStatus.DISABLED
assert self.scheduler.enable_job("test-1") is True
assert job.enabled is True
@@ -149,7 +149,7 @@ class TestCronScheduler:
self.scheduler.add_job("test-1", "任务1", "* * * * *", task)
self.scheduler.add_job("test-2", "任务2", "*/5 * * * *", task)
jobs = self.scheduler.list_jobs()
assert len(jobs) == 2
@@ -159,7 +159,7 @@ class TestCronScheduler:
pass
self.scheduler.add_job("test-1", "任务1", "* * * * *", task)
stats = self.scheduler.get_stats()
assert stats["total_jobs"] == 1
assert stats["enabled_jobs"] == 1

View File

@@ -4,8 +4,6 @@ from __future__ import annotations
import time
import pytest
from minenasai.webtui.auth import AuthManager, AuthToken
@@ -45,7 +43,7 @@ class TestAuthManager:
def test_generate_token(self):
"""测试生成令牌"""
token = self.manager.generate_token("user1")
assert token is not None
assert len(token) > 20
@@ -53,7 +51,7 @@ class TestAuthManager:
"""测试验证令牌"""
token = self.manager.generate_token("user1")
auth_token = self.manager.verify_token(token)
assert auth_token is not None
assert auth_token.user_id == "user1"
@@ -65,17 +63,17 @@ class TestAuthManager:
def test_verify_expired_token(self):
"""测试验证过期令牌"""
token = self.manager.generate_token("user1", expires_in=0)
# 等待过期
time.sleep(0.1)
auth_token = self.manager.verify_token(token)
assert auth_token is None
def test_revoke_token(self):
"""测试撤销令牌"""
token = self.manager.generate_token("user1")
assert self.manager.revoke_token(token) is True
assert self.manager.verify_token(token) is None
@@ -88,9 +86,9 @@ class TestAuthManager:
self.manager.generate_token("user1")
self.manager.generate_token("user1")
self.manager.generate_token("user2")
count = self.manager.revoke_user_tokens("user1")
assert count == 2
assert self.manager.get_stats()["total_tokens"] == 1
@@ -98,7 +96,7 @@ class TestAuthManager:
"""测试刷新令牌"""
old_token = self.manager.generate_token("user1")
new_token = self.manager.refresh_token(old_token)
assert new_token is not None
assert new_token != old_token
assert self.manager.verify_token(old_token) is None
@@ -108,9 +106,9 @@ class TestAuthManager:
"""测试令牌元数据"""
metadata = {"channel": "wework", "task_id": "123"}
token = self.manager.generate_token("user1", metadata=metadata)
auth_token = self.manager.verify_token(token)
assert auth_token is not None
assert auth_token.metadata == metadata
@@ -118,10 +116,10 @@ class TestAuthManager:
"""测试清理过期令牌"""
self.manager.generate_token("user1", expires_in=0)
self.manager.generate_token("user2", expires_in=3600)
time.sleep(0.1)
count = self.manager.cleanup_expired()
assert count == 1
assert self.manager.get_stats()["total_tokens"] == 1
@@ -132,17 +130,17 @@ class TestSSHManager:
def test_import_ssh_manager(self):
"""测试导入 SSH 管理器"""
from minenasai.webtui import SSHManager, get_ssh_manager
manager = get_ssh_manager()
assert isinstance(manager, SSHManager)
def test_ssh_manager_stats(self):
"""测试 SSH 管理器统计"""
from minenasai.webtui import SSHManager
manager = SSHManager()
stats = manager.get_stats()
assert "active_sessions" in stats
assert stats["active_sessions"] == 0
@@ -153,6 +151,6 @@ class TestWebTUIServer:
def test_import_server(self):
"""测试导入服务器"""
from minenasai.webtui.server import app
assert app is not None
assert app.title == "MineNASAI Web TUI"